diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py index f4fc5f4a2..b4afc05b3 100644 --- a/cognee/eval_framework/answer_generation/answer_generation_executor.py +++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py @@ -2,7 +2,7 @@ import cognee from typing import List, Dict, Callable, Awaitable from cognee.api.v1.search import SearchType -question_answering_engine_options: Dict[str, Callable[[str], Awaitable[List[str]]]] = { +question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = { "cognee_graph_completion": lambda query, system_prompt_path: cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text=query, @@ -24,14 +24,13 @@ class AnswerGeneratorExecutor: self, questions: List[Dict[str, str]], answer_resolver: Callable[[str], Awaitable[List[str]]], - system_prompt: str = "answer_simple_question.txt", ) -> List[Dict[str, str]]: answers = [] for instance in questions: query_text = instance["question"] correct_answer = instance["answer"] - search_results = await answer_resolver(query_text, system_prompt) + search_results = await answer_resolver(query_text) answers.append( { diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py index 42b31d44b..9caf71f6a 100644 --- a/cognee/eval_framework/answer_generation/run_question_answering_module.py +++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py @@ -48,8 +48,9 @@ async def run_question_answering( answer_generator = AnswerGeneratorExecutor() answers = await answer_generator.question_answering_non_parallel( questions=questions, - answer_resolver=question_answering_engine_options[params["qa_engine"]], - system_prompt=system_prompt, + answer_resolver=lambda query: question_answering_engine_options[params["qa_engine"]]( + query, system_prompt + ), ) with open(params["answers_path"], "w", encoding="utf-8") as f: json.dump(answers, f, ensure_ascii=False, indent=4) diff --git a/cognee/tests/unit/eval_framework/answer_generation_test.py b/cognee/tests/unit/eval_framework/answer_generation_test.py index 5e6ae3a02..d02ffd27d 100644 --- a/cognee/tests/unit/eval_framework/answer_generation_test.py +++ b/cognee/tests/unit/eval_framework/answer_generation_test.py @@ -12,11 +12,11 @@ async def test_answer_generation(): corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit) mock_answer_resolver = AsyncMock() - mock_answer_resolver.side_effect = lambda query, system_prompt: ["mock_answer"] + mock_answer_resolver.side_effect = lambda query: ["mock_answer"] answer_generator = AnswerGeneratorExecutor() answers = await answer_generator.question_answering_non_parallel( - questions=qa_pairs, answer_resolver=mock_answer_resolver, system_prompt="test.txt" + questions=qa_pairs, answer_resolver=mock_answer_resolver ) assert len(answers) == len(qa_pairs)