fix: add currying to question_answering_non_parallel (#602)
…l to avoid additional params <!-- .github/pull_request_template.md --> Introduces lambda currying in question answering non parallel function to avoid unnecessary params ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Streamlined the question-answering process for cleaner, more efficient query handling. - Updated the handling of parameters in the answer generation process, allowing for a more dynamic integration of context. - Simplified test setups by reducing the number of parameters involved in the mock answer resolver. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
cade574bbf
commit
3e93dbe264
3 changed files with 7 additions and 7 deletions
|
|
@ -2,7 +2,7 @@ import cognee
|
||||||
from typing import List, Dict, Callable, Awaitable
|
from typing import List, Dict, Callable, Awaitable
|
||||||
from cognee.api.v1.search import SearchType
|
from cognee.api.v1.search import SearchType
|
||||||
|
|
||||||
question_answering_engine_options: Dict[str, Callable[[str], Awaitable[List[str]]]] = {
|
question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
|
||||||
"cognee_graph_completion": lambda query, system_prompt_path: cognee.search(
|
"cognee_graph_completion": lambda query, system_prompt_path: cognee.search(
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
query_text=query,
|
query_text=query,
|
||||||
|
|
@ -24,14 +24,13 @@ class AnswerGeneratorExecutor:
|
||||||
self,
|
self,
|
||||||
questions: List[Dict[str, str]],
|
questions: List[Dict[str, str]],
|
||||||
answer_resolver: Callable[[str], Awaitable[List[str]]],
|
answer_resolver: Callable[[str], Awaitable[List[str]]],
|
||||||
system_prompt: str = "answer_simple_question.txt",
|
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
answers = []
|
answers = []
|
||||||
for instance in questions:
|
for instance in questions:
|
||||||
query_text = instance["question"]
|
query_text = instance["question"]
|
||||||
correct_answer = instance["answer"]
|
correct_answer = instance["answer"]
|
||||||
|
|
||||||
search_results = await answer_resolver(query_text, system_prompt)
|
search_results = await answer_resolver(query_text)
|
||||||
|
|
||||||
answers.append(
|
answers.append(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -48,8 +48,9 @@ async def run_question_answering(
|
||||||
answer_generator = AnswerGeneratorExecutor()
|
answer_generator = AnswerGeneratorExecutor()
|
||||||
answers = await answer_generator.question_answering_non_parallel(
|
answers = await answer_generator.question_answering_non_parallel(
|
||||||
questions=questions,
|
questions=questions,
|
||||||
answer_resolver=question_answering_engine_options[params["qa_engine"]],
|
answer_resolver=lambda query: question_answering_engine_options[params["qa_engine"]](
|
||||||
system_prompt=system_prompt,
|
query, system_prompt
|
||||||
|
),
|
||||||
)
|
)
|
||||||
with open(params["answers_path"], "w", encoding="utf-8") as f:
|
with open(params["answers_path"], "w", encoding="utf-8") as f:
|
||||||
json.dump(answers, f, ensure_ascii=False, indent=4)
|
json.dump(answers, f, ensure_ascii=False, indent=4)
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,11 @@ async def test_answer_generation():
|
||||||
corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
|
corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
|
||||||
|
|
||||||
mock_answer_resolver = AsyncMock()
|
mock_answer_resolver = AsyncMock()
|
||||||
mock_answer_resolver.side_effect = lambda query, system_prompt: ["mock_answer"]
|
mock_answer_resolver.side_effect = lambda query: ["mock_answer"]
|
||||||
|
|
||||||
answer_generator = AnswerGeneratorExecutor()
|
answer_generator = AnswerGeneratorExecutor()
|
||||||
answers = await answer_generator.question_answering_non_parallel(
|
answers = await answer_generator.question_answering_non_parallel(
|
||||||
questions=qa_pairs, answer_resolver=mock_answer_resolver, system_prompt="test.txt"
|
questions=qa_pairs, answer_resolver=mock_answer_resolver
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(answers) == len(qa_pairs)
|
assert len(answers) == len(qa_pairs)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue