feat: COG-1523 add top_k in run_question_answering (#625)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> - Expose top_k as an optional argument of run_question_answering - Update retrievers to handle the parameters ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced answer generation and document retrieval capabilities by introducing an optional parameter that allows users to specify the number of top results. This improvement adds flexibility when retrieving question responses and associated context, adapting the output based on user preference. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
56427f287e
commit
ac0156514d
4 changed files with 11 additions and 7 deletions
|
|
@@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import json
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from cognee.eval_framework.answer_generation.answer_generation_executor import (
|
||||
AnswerGeneratorExecutor,
|
||||
retriever_options,
|
||||
|
|
@ -32,7 +32,7 @@ async def create_and_insert_answers_table(questions_payload):
|
|||
|
||||
|
||||
async def run_question_answering(
|
||||
params: dict, system_prompt="answer_simple_question.txt"
|
||||
params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
if params.get("answering_questions"):
|
||||
logging.info("Question answering started...")
|
||||
|
|
@@ -48,7 +48,9 @@ async def run_question_answering(
|
|||
answer_generator = AnswerGeneratorExecutor()
|
||||
answers = await answer_generator.question_answering_non_parallel(
|
||||
questions=questions,
|
||||
retriever=retriever_options[params["qa_engine"]](system_prompt_path=system_prompt),
|
||||
retriever=retriever_options[params["qa_engine"]](
|
||||
system_prompt_path=system_prompt, top_k=top_k
|
||||
),
|
||||
)
|
||||
with open(params["answers_path"], "w", encoding="utf-8") as f:
|
||||
json.dump(answers, f, ensure_ascii=False, indent=4)
|
||||
|
|
|
|||
|
|
@@ -13,15 +13,17 @@ class CompletionRetriever(BaseRetriever):
|
|||
self,
|
||||
user_prompt_path: str = "context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: Optional[int] = 1,
|
||||
):
|
||||
"""Initialize retriever with optional custom prompt paths."""
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.top_k = top_k if top_k is not None else 1
|
||||
|
||||
async def get_context(self, query: str) -> Any:
|
||||
"""Retrieves relevant document chunks as context."""
|
||||
vector_engine = get_vector_engine()
|
||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1)
|
||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
||||
if len(found_chunks) == 0:
|
||||
raise NoRelevantDataFound
|
||||
return found_chunks[0].payload["text"]
|
||||
|
|
|
|||
|
|
@@ -15,12 +15,12 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
self,
|
||||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: int = 5,
|
||||
top_k: Optional[int] = 5,
|
||||
):
|
||||
"""Initialize retriever with prompt paths and search parameters."""
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
self.top_k = top_k
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
|
||||
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
||||
"""Converts retrieved graph edges into a human-readable string format."""
|
||||
|
|
|
|||
|
|
@@ -12,7 +12,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
user_prompt_path: str = "graph_context_for_question.txt",
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
summarize_prompt_path: str = "summarize_search_results.txt",
|
||||
top_k: int = 5,
|
||||
top_k: Optional[int] = 5,
|
||||
):
|
||||
"""Initialize retriever with default prompt paths and search parameters."""
|
||||
super().__init__(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue