feat: generate improved answers temp
This commit is contained in:
parent
78fca9feb7
commit
97eb89386e
1 changed files with 132 additions and 0 deletions
132
cognee/tasks/feedback/generate_improved_answers.py
Normal file
132
cognee/tasks/feedback/generate_improved_answers.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .utils import create_retriever
|
||||
|
||||
|
||||
class ImprovedAnswerResponse(BaseModel):
|
||||
"""Response model for improved answer generation containing answer and explanation."""
|
||||
|
||||
answer: str
|
||||
explanation: str
|
||||
|
||||
|
||||
logger = get_logger("generate_improved_answers")
|
||||
|
||||
|
||||
def _validate_input_data(feedback_interactions: List[Dict]) -> bool:
|
||||
"""Validate that input contains required fields for all items."""
|
||||
required_fields = [
|
||||
"question",
|
||||
"answer",
|
||||
"context",
|
||||
"feedback_text",
|
||||
"feedback_id",
|
||||
"interaction_id",
|
||||
]
|
||||
return all(
|
||||
all(item.get(field) is not None for field in required_fields)
|
||||
for item in feedback_interactions
|
||||
)
|
||||
|
||||
|
||||
def _render_reaction_prompt(
|
||||
question: str, context: str, wrong_answer: str, negative_feedback: str
|
||||
) -> str:
|
||||
"""Render the feedback reaction prompt with provided variables."""
|
||||
prompt_template = read_query_prompt("feedback_reaction_prompt.txt")
|
||||
return prompt_template.format(
|
||||
question=question,
|
||||
context=context,
|
||||
wrong_answer=wrong_answer,
|
||||
negative_feedback=negative_feedback,
|
||||
)
|
||||
|
||||
|
||||
async def _generate_improved_answer_for_single_interaction(
|
||||
feedback_interaction: Dict, retriever, reaction_prompt_location: str
|
||||
) -> Optional[Dict]:
|
||||
"""Generate improved answer for a single feedback-interaction pair using structured retriever completion."""
|
||||
try:
|
||||
question_text = feedback_interaction["question"]
|
||||
original_answer_text = feedback_interaction["answer"]
|
||||
context_text = feedback_interaction["context"]
|
||||
feedback_text = feedback_interaction["feedback_text"]
|
||||
|
||||
query_text = _render_reaction_prompt(
|
||||
question_text, context_text, original_answer_text, feedback_text
|
||||
)
|
||||
|
||||
retrieved_context = await retriever.get_context(query_text)
|
||||
completion, new_context_text = await retriever.get_structured_completion(
|
||||
query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse
|
||||
)
|
||||
|
||||
if completion:
|
||||
return {
|
||||
**feedback_interaction,
|
||||
"improved_answer": completion.answer,
|
||||
"new_context": new_context_text,
|
||||
"explanation": completion.explanation,
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to get structured completion from retriever", question=question_text
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(
|
||||
"Failed to generate improved answer",
|
||||
error=str(exc),
|
||||
question=feedback_interaction.get("question"),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def generate_improved_answers(
|
||||
feedback_interactions: List[Dict],
|
||||
retriever_name: str = "graph_completion_cot",
|
||||
top_k: int = 20,
|
||||
reaction_prompt_location: str = "feedback_reaction_prompt.txt",
|
||||
) -> List[Dict]:
|
||||
"""Generate improved answers using configurable retriever and LLM."""
|
||||
if not feedback_interactions:
|
||||
logger.info("No feedback interactions provided; returning empty list")
|
||||
return []
|
||||
|
||||
if not _validate_input_data(feedback_interactions):
|
||||
logger.error("Input data validation failed; missing required fields")
|
||||
return []
|
||||
|
||||
retriever = create_retriever(
|
||||
retriever_name=retriever_name,
|
||||
top_k=top_k,
|
||||
user_prompt_path="graph_context_for_question.txt",
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
)
|
||||
|
||||
improved_answers: List[Dict] = []
|
||||
successful_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for feedback_interaction in feedback_interactions:
|
||||
result = await _generate_improved_answer_for_single_interaction(
|
||||
feedback_interaction, retriever, reaction_prompt_location
|
||||
)
|
||||
|
||||
if result:
|
||||
improved_answers.append(result)
|
||||
successful_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
logger.info("Generated improved answers", successful=successful_count, failed=failed_count)
|
||||
return improved_answers
|
||||
Loading…
Add table
Reference in a new issue