fix: add currying to question_answering_non_parallel (#602)



Introduces lambda currying around the `question_answering_non_parallel` answer resolver so the executor no longer needs to be passed an extra `system_prompt` parameter.
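To illustrate the idea, here is a minimal, self-contained sketch of the pattern; the `answer_with_prompt` coroutine below is a hypothetical placeholder, not a cognee API. A lambda binds the prompt once so downstream code only ever sees a single-argument resolver.

```python
import asyncio
from typing import Awaitable, Callable, List

# Hypothetical two-parameter engine; stands in for the real cognee search options.
async def answer_with_prompt(query: str, system_prompt_path: str) -> List[str]:
    return [f"answer to {query!r} using {system_prompt_path}"]

# Currying: bind the prompt up front, leaving a single-argument resolver
# that matches Callable[[str], Awaitable[List[str]]].
system_prompt = "answer_simple_question.txt"
answer_resolver: Callable[[str], Awaitable[List[str]]] = lambda query: answer_with_prompt(
    query, system_prompt
)

print(asyncio.run(answer_resolver("What is lambda currying?")))
```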

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


## Summary by CodeRabbit

- **Refactor**
  - Streamlined the question-answering process for cleaner, more efficient query handling.
  - Updated the handling of parameters in the answer generation process, allowing for a more dynamic integration of context.
  - Simplified test setups by reducing the number of parameters involved in the mock answer resolver.
Commit 3e93dbe264 (parent cade574bbf), authored by hajdul88 on 2025-03-04 16:09:53 +01:00 and committed by GitHub.
3 changed files with 7 additions and 7 deletions


@@ -2,7 +2,7 @@ import cognee
 from typing import List, Dict, Callable, Awaitable
 from cognee.api.v1.search import SearchType

-question_answering_engine_options: Dict[str, Callable[[str], Awaitable[List[str]]]] = {
+question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
     "cognee_graph_completion": lambda query, system_prompt_path: cognee.search(
         query_type=SearchType.GRAPH_COMPLETION,
         query_text=query,
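The widened value type documents what the dict already stores: lambdas that take both the query and the system prompt path. A rough sketch of that contract, with a stand-in coroutine in place of `cognee.search` (whose real signature is not reproduced here):

```python
import asyncio
from typing import Awaitable, Callable, Dict, List

# Stand-in for cognee.search; only the parameter shape matters for this sketch.
async def _graph_completion(query_text: str, system_prompt_path: str) -> List[str]:
    return [f"completion for {query_text!r} with prompt {system_prompt_path}"]

question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
    "cognee_graph_completion": lambda query, system_prompt_path: _graph_completion(
        query_text=query, system_prompt_path=system_prompt_path
    ),
}

# Callers must now supply both arguments -- or curry one away,
# as the run_question_answering change further down does.
result = asyncio.run(
    question_answering_engine_options["cognee_graph_completion"](
        "What is a knowledge graph?", "answer_simple_question.txt"
    )
)
print(result)
```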
@@ -24,14 +24,13 @@ class AnswerGeneratorExecutor:
         self,
         questions: List[Dict[str, str]],
         answer_resolver: Callable[[str], Awaitable[List[str]]],
-        system_prompt: str = "answer_simple_question.txt",
     ) -> List[Dict[str, str]]:
         answers = []
         for instance in questions:
             query_text = instance["question"]
             correct_answer = instance["answer"]

-            search_results = await answer_resolver(query_text, system_prompt)
+            search_results = await answer_resolver(query_text)

             answers.append(
                 {
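Because the executor now only awaits `answer_resolver(query_text)`, any single-argument coroutine-returning callable will do. The diff is truncated above, so here is a hedged, simplified sketch of the loop written as a free function; the output keys are illustrative, not necessarily the exact fields the class records:

```python
import asyncio
from typing import Awaitable, Callable, Dict, List

async def question_answering_non_parallel(
    questions: List[Dict[str, str]],
    answer_resolver: Callable[[str], Awaitable[List[str]]],
) -> List[Dict[str, str]]:
    answers = []
    for instance in questions:
        query_text = instance["question"]
        correct_answer = instance["answer"]

        # The resolver already has its system prompt bound, so one argument suffices.
        search_results = await answer_resolver(query_text)

        answers.append(
            {
                "question": query_text,       # illustrative keys; the real class
                "answer": search_results[0],  # may record additional fields
                "golden_answer": correct_answer,
            }
        )
    return answers

async def fake_search(query: str) -> List[str]:
    # Hypothetical pre-bound resolver; a curried engine option in the real setup.
    return [f"mocked answer for {query!r}"]

async def demo() -> None:
    questions = [{"question": "What colour is the sky?", "answer": "blue"}]
    print(await question_answering_non_parallel(questions, fake_search))

asyncio.run(demo())
```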


@@ -48,8 +48,9 @@ async def run_question_answering(
     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
         questions=questions,
-        answer_resolver=question_answering_engine_options[params["qa_engine"]],
-        system_prompt=system_prompt,
+        answer_resolver=lambda query: question_answering_engine_options[params["qa_engine"]](
+            query, system_prompt
+        ),
     )
     with open(params["answers_path"], "w", encoding="utf-8") as f:
         json.dump(answers, f, ensure_ascii=False, indent=4)
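An equivalent way to express the binding, if a named callable is preferred over a lambda, is `functools.partial`. This is only an alternative sketch over a hypothetical engine table, not what the PR ships:

```python
import asyncio
from functools import partial
from typing import Awaitable, Callable, Dict, List

# Hypothetical engine table standing in for question_answering_engine_options.
async def _engine(query: str, system_prompt: str) -> List[str]:
    return [f"{query} -> answered with {system_prompt}"]

engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {"demo": _engine}

params = {"qa_engine": "demo"}
system_prompt = "answer_simple_question.txt"

# partial binds the trailing argument, producing the single-argument resolver
# the executor expects -- the same effect as the lambda in the diff above.
answer_resolver = partial(engine_options[params["qa_engine"]], system_prompt=system_prompt)

print(asyncio.run(answer_resolver("What does currying buy us here?")))
```

The shipped lambda keeps the binding positional, so it does not depend on the keyword name of the second parameter (`system_prompt_path` in the real engine lambdas); either spelling achieves the same currying.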


@@ -12,11 +12,11 @@ async def test_answer_generation():
     corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)

     mock_answer_resolver = AsyncMock()
-    mock_answer_resolver.side_effect = lambda query, system_prompt: ["mock_answer"]
+    mock_answer_resolver.side_effect = lambda query: ["mock_answer"]

     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
-        questions=qa_pairs, answer_resolver=mock_answer_resolver, system_prompt="test.txt"
+        questions=qa_pairs, answer_resolver=mock_answer_resolver
     )

     assert len(answers) == len(qa_pairs)
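For reference, a self-contained sketch of how the single-argument `AsyncMock` behaves after this change (pytest and the real `AnswerGeneratorExecutor` are left out; the assertions only mirror the test's intent):

```python
import asyncio
from unittest.mock import AsyncMock

mock_answer_resolver = AsyncMock()
# With AsyncMock, a plain function as side_effect supplies the awaited return value.
mock_answer_resolver.side_effect = lambda query: ["mock_answer"]

async def main() -> None:
    result = await mock_answer_resolver("What colour is the sky?")
    assert result == ["mock_answer"]
    # The mock records the single-argument call shape the executor now uses.
    mock_answer_resolver.assert_awaited_once_with("What colour is the sky?")

asyncio.run(main())
```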