feat: COG-1523 add top_k in run_question_answering (#625)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Expose top_k as an optional argument of run_question_answering
- Update retrievers to handle the parameters

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced answer generation and document retrieval capabilities by
introducing an optional parameter that allows users to specify the
number of top results. This improvement adds flexibility when retrieving
question responses and associated context, adapting the output based on
user preference.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
lxobr 2025-03-10 10:55:31 +01:00 committed by GitHub
parent 56427f287e
commit ac0156514d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 11 additions and 7 deletions

View file

@ -1,6 +1,6 @@
import logging
import json
from typing import List
from typing import List, Optional
from cognee.eval_framework.answer_generation.answer_generation_executor import (
AnswerGeneratorExecutor,
retriever_options,
@ -32,7 +32,7 @@ async def create_and_insert_answers_table(questions_payload):
async def run_question_answering(
params: dict, system_prompt="answer_simple_question.txt"
params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
) -> List[dict]:
if params.get("answering_questions"):
logging.info("Question answering started...")
@ -48,7 +48,9 @@ async def run_question_answering(
answer_generator = AnswerGeneratorExecutor()
answers = await answer_generator.question_answering_non_parallel(
questions=questions,
retriever=retriever_options[params["qa_engine"]](system_prompt_path=system_prompt),
retriever=retriever_options[params["qa_engine"]](
system_prompt_path=system_prompt, top_k=top_k
),
)
with open(params["answers_path"], "w", encoding="utf-8") as f:
json.dump(answers, f, ensure_ascii=False, indent=4)

View file

@ -13,15 +13,17 @@ class CompletionRetriever(BaseRetriever):
self,
user_prompt_path: str = "context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
top_k: Optional[int] = 1,
):
"""Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1
async def get_context(self, query: str) -> Any:
"""Retrieves relevant document chunks as context."""
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1)
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
if len(found_chunks) == 0:
raise NoRelevantDataFound
return found_chunks[0].payload["text"]

View file

@ -15,12 +15,12 @@ class GraphCompletionRetriever(BaseRetriever):
self,
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 5,
top_k: Optional[int] = 5,
):
"""Initialize retriever with prompt paths and search parameters."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k
self.top_k = top_k if top_k is not None else 5
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
"""Converts retrieved graph edges into a human-readable string format."""

View file

@ -12,7 +12,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
summarize_prompt_path: str = "summarize_search_results.txt",
top_k: int = 5,
top_k: Optional[int] = 5,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(