From 3e93dbe264b5ac34318ccc501f5243c9a39e2be0 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Tue, 4 Mar 2025 16:09:53 +0100
Subject: [PATCH] fix: add currying to question_answering_non_parallel (#602)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…l to avoid additional params

Introduces lambda currying in the question_answering_non_parallel function to avoid passing unnecessary parameters.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

## Summary by CodeRabbit

- **Refactor**
  - Streamlined the question-answering process for cleaner, more efficient query handling.
  - Updated the handling of parameters in the answer generation process, allowing for a more dynamic integration of context.
  - Simplified test setups by reducing the number of parameters involved in the mock answer resolver.
---
 .../answer_generation/answer_generation_executor.py        | 5 ++---
 .../answer_generation/run_question_answering_module.py     | 5 +++--
 cognee/tests/unit/eval_framework/answer_generation_test.py | 4 ++--
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py
index f4fc5f4a2..b4afc05b3 100644
--- a/cognee/eval_framework/answer_generation/answer_generation_executor.py
+++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py
@@ -2,7 +2,7 @@ import cognee
 from typing import List, Dict, Callable, Awaitable
 from cognee.api.v1.search import SearchType
 
-question_answering_engine_options: Dict[str, Callable[[str], Awaitable[List[str]]]] = {
+question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
     "cognee_graph_completion": lambda query, system_prompt_path: cognee.search(
         query_type=SearchType.GRAPH_COMPLETION,
         query_text=query,
@@ -24,14 +24,13 @@ class AnswerGeneratorExecutor:
         self,
         questions: List[Dict[str, str]],
         answer_resolver: Callable[[str], Awaitable[List[str]]],
-        system_prompt: str = "answer_simple_question.txt",
     ) -> List[Dict[str, str]]:
         answers = []
         for instance in questions:
             query_text = instance["question"]
             correct_answer = instance["answer"]
 
-            search_results = await answer_resolver(query_text, system_prompt)
+            search_results = await answer_resolver(query_text)
 
             answers.append(
                 {
diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py
index 42b31d44b..9caf71f6a 100644
--- a/cognee/eval_framework/answer_generation/run_question_answering_module.py
+++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py
@@ -48,8 +48,9 @@ async def run_question_answering(
     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
         questions=questions,
-        answer_resolver=question_answering_engine_options[params["qa_engine"]],
-        system_prompt=system_prompt,
+        answer_resolver=lambda query: question_answering_engine_options[params["qa_engine"]](
+            query, system_prompt
+        ),
     )
     with open(params["answers_path"], "w", encoding="utf-8") as f:
         json.dump(answers, f, ensure_ascii=False, indent=4)
diff --git a/cognee/tests/unit/eval_framework/answer_generation_test.py b/cognee/tests/unit/eval_framework/answer_generation_test.py
index 5e6ae3a02..d02ffd27d 100644
--- a/cognee/tests/unit/eval_framework/answer_generation_test.py
+++ b/cognee/tests/unit/eval_framework/answer_generation_test.py
@@ -12,11 +12,11 @@ async def test_answer_generation():
    corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
 
     mock_answer_resolver = AsyncMock()
-    mock_answer_resolver.side_effect = lambda query, system_prompt: ["mock_answer"]
+    mock_answer_resolver.side_effect = lambda query: ["mock_answer"]
 
     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
-        questions=qa_pairs, answer_resolver=mock_answer_resolver, system_prompt="test.txt"
+        questions=qa_pairs, answer_resolver=mock_answer_resolver
     )
 
     assert len(answers) == len(qa_pairs)
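
The currying pattern applied by this patch can be illustrated outside of cognee with a minimal, self-contained sketch: a registry of two-argument engines is wrapped in a lambda that binds the system prompt at the call site, so the executor only ever sees a single-argument resolver. All names below (`fake_search`, `engine_options`, `answer_questions`) are hypothetical stand-ins for illustration, not cognee APIs.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List


async def fake_search(query: str, system_prompt_path: str) -> List[str]:
    # Hypothetical stand-in for a two-argument engine call such as cognee.search(...).
    return [f"answer to {query!r} using {system_prompt_path}"]


# Engines are registered with a (query, system_prompt_path) signature.
engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
    "fake_engine": lambda query, system_prompt_path: fake_search(query, system_prompt_path),
}


async def answer_questions(
    questions: List[str],
    answer_resolver: Callable[[str], Awaitable[List[str]]],
) -> List[List[str]]:
    # The executor only sees a single-argument resolver, mirroring
    # question_answering_non_parallel after this patch.
    return [await answer_resolver(question) for question in questions]


async def main() -> None:
    system_prompt = "answer_simple_question.txt"
    # Currying: bind the prompt once so the resolver exposes only (query).
    resolver = lambda query: engine_options["fake_engine"](query, system_prompt)
    print(await answer_questions(["What does currying buy us here?"], resolver))


asyncio.run(main())
```

Binding the prompt where the engine is selected keeps the executor agnostic about engine-specific parameters, which is why the `system_prompt` argument and the extra mock parameter could be dropped in the diff above.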