fix: add currying to question_answering_non_parallel (#602)
…l to avoid additional params <!-- .github/pull_request_template.md --> Introduces lambda currying in question answering non parallel function to avoid unnecessary params ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Streamlined the question-answering process for cleaner, more efficient query handling. - Updated the handling of parameters in the answer generation process, allowing for a more dynamic integration of context. - Simplified test setups by reducing the number of parameters involved in the mock answer resolver. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
cade574bbf
commit
3e93dbe264
3 changed files with 7 additions and 7 deletions
|
|
@ -2,7 +2,7 @@ import cognee
|
|||
from typing import List, Dict, Callable, Awaitable
|
||||
from cognee.api.v1.search import SearchType
|
||||
|
||||
question_answering_engine_options: Dict[str, Callable[[str], Awaitable[List[str]]]] = {
|
||||
question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
|
||||
"cognee_graph_completion": lambda query, system_prompt_path: cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text=query,
|
||||
|
|
@ -24,14 +24,13 @@ class AnswerGeneratorExecutor:
|
|||
self,
|
||||
questions: List[Dict[str, str]],
|
||||
answer_resolver: Callable[[str], Awaitable[List[str]]],
|
||||
system_prompt: str = "answer_simple_question.txt",
|
||||
) -> List[Dict[str, str]]:
|
||||
answers = []
|
||||
for instance in questions:
|
||||
query_text = instance["question"]
|
||||
correct_answer = instance["answer"]
|
||||
|
||||
search_results = await answer_resolver(query_text, system_prompt)
|
||||
search_results = await answer_resolver(query_text)
|
||||
|
||||
answers.append(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -48,8 +48,9 @@ async def run_question_answering(
|
|||
answer_generator = AnswerGeneratorExecutor()
|
||||
answers = await answer_generator.question_answering_non_parallel(
|
||||
questions=questions,
|
||||
answer_resolver=question_answering_engine_options[params["qa_engine"]],
|
||||
system_prompt=system_prompt,
|
||||
answer_resolver=lambda query: question_answering_engine_options[params["qa_engine"]](
|
||||
query, system_prompt
|
||||
),
|
||||
)
|
||||
with open(params["answers_path"], "w", encoding="utf-8") as f:
|
||||
json.dump(answers, f, ensure_ascii=False, indent=4)
|
||||
|
|
|
|||
|
|
@ -12,11 +12,11 @@ async def test_answer_generation():
|
|||
corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
|
||||
|
||||
mock_answer_resolver = AsyncMock()
|
||||
mock_answer_resolver.side_effect = lambda query, system_prompt: ["mock_answer"]
|
||||
mock_answer_resolver.side_effect = lambda query: ["mock_answer"]
|
||||
|
||||
answer_generator = AnswerGeneratorExecutor()
|
||||
answers = await answer_generator.question_answering_non_parallel(
|
||||
questions=qa_pairs, answer_resolver=mock_answer_resolver, system_prompt="test.txt"
|
||||
questions=qa_pairs, answer_resolver=mock_answer_resolver
|
||||
)
|
||||
|
||||
assert len(answers) == len(qa_pairs)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue