From ac0156514d6dbc237559ac44a937fef330a679c3 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 10 Mar 2025 10:55:31 +0100 Subject: [PATCH] feat: COG-1523 add top_k in run_question_answering (#625) ## Description - Expose top_k as an optional argument of run_question_answering - Update retrievers to handle the parameters ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin ## Summary by CodeRabbit - **New Features** - Enhanced answer generation and document retrieval capabilities by introducing an optional parameter that allows users to specify the number of top results. This improvement adds flexibility when retrieving question responses and associated context, adapting the output based on user preference. --- .../answer_generation/run_question_answering_module.py | 8 +++++--- cognee/modules/retrieval/completion_retriever.py | 4 +++- cognee/modules/retrieval/graph_completion_retriever.py | 4 ++-- .../retrieval/graph_summary_completion_retriever.py | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py index 1d3686efb..70938a451 100644 --- a/cognee/eval_framework/answer_generation/run_question_answering_module.py +++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py @@ -1,6 +1,6 @@ import logging import json -from typing import List +from typing import List, Optional from cognee.eval_framework.answer_generation.answer_generation_executor import ( AnswerGeneratorExecutor, retriever_options, @@ -32,7 +32,7 @@ async def create_and_insert_answers_table(questions_payload): async def run_question_answering( - params: dict, system_prompt="answer_simple_question.txt" + params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None ) -> List[dict]: if params.get("answering_questions"): logging.info("Question answering started...") @@ -48,7 +48,9 @@ async def run_question_answering( answer_generator = AnswerGeneratorExecutor() answers = await answer_generator.question_answering_non_parallel( questions=questions, - retriever=retriever_options[params["qa_engine"]](system_prompt_path=system_prompt), + retriever=retriever_options[params["qa_engine"]]( + system_prompt_path=system_prompt, top_k=top_k + ), ) with open(params["answers_path"], "w", encoding="utf-8") as f: json.dump(answers, f, ensure_ascii=False, indent=4) diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index f2427f062..cf8600f27 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -13,15 +13,17 @@ class CompletionRetriever(BaseRetriever): self, user_prompt_path: str = "context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + top_k: Optional[int] = 1, ): """Initialize retriever with optional custom prompt paths.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path + self.top_k = top_k if top_k is not None else 1 async def get_context(self, query: str) -> Any: """Retrieves relevant document chunks as context.""" vector_engine = get_vector_engine() - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1) + found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) if len(found_chunks) == 0: raise NoRelevantDataFound return found_chunks[0].payload["text"] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 709415fa7..80b8855d1 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -15,12 +15,12 @@ class GraphCompletionRetriever(BaseRetriever): self, user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", - top_k: int = 5, + top_k: Optional[int] = 5, ): """Initialize retriever with prompt paths and search parameters.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path - self.top_k = top_k + self.top_k = top_k if top_k is not None else 5 async def resolve_edges_to_text(self, retrieved_edges: list) -> str: """Converts retrieved graph edges into a human-readable string format.""" diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index 536bafe5d..76ed5f5d4 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -12,7 +12,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", summarize_prompt_path: str = "summarize_search_results.txt", - top_k: int = 5, + top_k: Optional[int] = 5, ): """Initialize retriever with default prompt paths and search parameters.""" super().__init__(