diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py index 1d3686efb..70938a451 100644 --- a/cognee/eval_framework/answer_generation/run_question_answering_module.py +++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py @@ -1,6 +1,6 @@ import logging import json -from typing import List +from typing import List, Optional from cognee.eval_framework.answer_generation.answer_generation_executor import ( AnswerGeneratorExecutor, retriever_options, @@ -32,7 +32,7 @@ async def create_and_insert_answers_table(questions_payload): async def run_question_answering( - params: dict, system_prompt="answer_simple_question.txt" + params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None ) -> List[dict]: if params.get("answering_questions"): logging.info("Question answering started...") @@ -48,7 +48,9 @@ async def run_question_answering( answer_generator = AnswerGeneratorExecutor() answers = await answer_generator.question_answering_non_parallel( questions=questions, - retriever=retriever_options[params["qa_engine"]](system_prompt_path=system_prompt), + retriever=retriever_options[params["qa_engine"]]( + system_prompt_path=system_prompt, top_k=top_k + ), ) with open(params["answers_path"], "w", encoding="utf-8") as f: json.dump(answers, f, ensure_ascii=False, indent=4) diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index f2427f062..cf8600f27 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -13,15 +13,17 @@ class CompletionRetriever(BaseRetriever): self, user_prompt_path: str = "context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", + top_k: Optional[int] = 1, ): """Initialize retriever with optional custom prompt paths.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path + self.top_k = top_k if top_k is not None else 1 async def get_context(self, query: str) -> Any: """Retrieves relevant document chunks as context.""" vector_engine = get_vector_engine() - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1) + found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) if len(found_chunks) == 0: raise NoRelevantDataFound return found_chunks[0].payload["text"] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 709415fa7..80b8855d1 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -15,12 +15,12 @@ class GraphCompletionRetriever(BaseRetriever): self, user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", - top_k: int = 5, + top_k: Optional[int] = 5, ): """Initialize retriever with prompt paths and search parameters.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path - self.top_k = top_k + self.top_k = top_k if top_k is not None else 5 async def resolve_edges_to_text(self, retrieved_edges: list) -> str: """Converts retrieved graph edges into a human-readable string format.""" diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index 536bafe5d..76ed5f5d4 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -12,7 +12,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", summarize_prompt_path: str = "summarize_search_results.txt", - top_k: int = 5, + top_k: Optional[int] = 5, ): """Initialize retriever with default prompt paths and search parameters.""" super().__init__(