fix: mock the dataset download in the HotpotQA and TwoWikiMultiHop tests (they previously fetched the dataset from a URL)
This commit is contained in:
parent
499c717f85
commit
fc89f71e7c
2 changed files with 58 additions and 4 deletions
|
|
@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
|
||||||
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
|
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
MOCK_HOTPOT_CORPUS = [
|
||||||
|
{
|
||||||
|
"_id": "1",
|
||||||
|
"question": "Next to which country is Germany located?",
|
||||||
|
"answer": "Netherlands",
|
||||||
|
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
|
||||||
|
"level": "easy",
|
||||||
|
"type": "comparison",
|
||||||
|
"context": [
|
||||||
|
["Germany", ["Germany is in Europe."]],
|
||||||
|
["Netherlands", ["The Netherlands borders Germany."]],
|
||||||
|
],
|
||||||
|
"supporting_facts": [["Netherlands", 0]],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
ADAPTER_CLASSES = [
|
ADAPTER_CLASSES = [
|
||||||
HotpotQAAdapter,
|
HotpotQAAdapter,
|
||||||
|
|
@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
|
||||||
adapter = AdapterClass()
|
adapter = AdapterClass()
|
||||||
result = adapter.load_corpus()
|
result = adapter.load_corpus()
|
||||||
|
|
||||||
|
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
||||||
|
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||||
|
adapter = AdapterClass()
|
||||||
|
result = adapter.load_corpus()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
adapter = AdapterClass()
|
adapter = AdapterClass()
|
||||||
result = adapter.load_corpus()
|
result = adapter.load_corpus()
|
||||||
|
|
@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
|
||||||
):
|
):
|
||||||
adapter = AdapterClass()
|
adapter = AdapterClass()
|
||||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||||
|
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
||||||
|
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||||
|
adapter = AdapterClass()
|
||||||
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||||
else:
|
else:
|
||||||
adapter = AdapterClass()
|
adapter = AdapterClass()
|
||||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,38 @@ import pytest
|
||||||
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
|
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
|
||||||
|
|
||||||
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
|
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
|
||||||
|
|
||||||
|
MOCK_HOTPOT_CORPUS = [
|
||||||
|
{
|
||||||
|
"_id": "1",
|
||||||
|
"question": "Next to which country is Germany located?",
|
||||||
|
"answer": "Netherlands",
|
||||||
|
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
|
||||||
|
"level": "easy",
|
||||||
|
"type": "comparison",
|
||||||
|
"context": [
|
||||||
|
["Germany", ["Germany is in Europe."]],
|
||||||
|
["Netherlands", ["The Netherlands borders Germany."]],
|
||||||
|
],
|
||||||
|
"supporting_facts": [["Netherlands", 0]],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("benchmark", benchmark_options)
|
@pytest.mark.parametrize("benchmark", benchmark_options)
|
||||||
def test_corpus_builder_load_corpus(benchmark):
|
def test_corpus_builder_load_corpus(benchmark):
|
||||||
limit = 2
|
limit = 2
|
||||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
||||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||||
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
|
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||||
|
else:
|
||||||
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
|
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||||
|
|
||||||
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
||||||
assert len(questions) <= 2, (
|
assert len(questions) <= 2, (
|
||||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||||
|
|
@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
|
||||||
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
|
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
|
||||||
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
|
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
|
||||||
limit = 2
|
limit = 2
|
||||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
||||||
questions = await corpus_builder.build_corpus(limit=limit)
|
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||||
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
|
questions = await corpus_builder.build_corpus(limit=limit)
|
||||||
|
else:
|
||||||
|
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||||
|
questions = await corpus_builder.build_corpus(limit=limit)
|
||||||
|
|
||||||
assert len(questions) <= 2, (
|
assert len(questions) <= 2, (
|
||||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue