cognee/cognee/tasks/feedback/generate_improved_answers.py

130 lines
4.3 KiB
Python

from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel
from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from .models import FeedbackEnrichment
# Structured-output schema passed as ``response_model`` to the retriever's
# ``get_completion`` when generating an improved answer.
# NOTE(review): pydantic may expose the class docstring as the JSON-schema
# description sent to the LLM, so the docstring text is deliberately unchanged.
class ImprovedAnswerResponse(BaseModel):
    """Response model for improved answer generation containing answer and explanation."""

    # Improved answer text; copied onto ``enrichment.improved_answer`` on success.
    answer: str
    # Model-provided explanation; copied onto ``enrichment.explanation`` on success.
    explanation: str
# Module-level logger; calls in this module attach context as keyword
# arguments (e.g. ``question=...``, ``count=...``) rather than formatting it
# into the message string.
logger = get_logger("generate_improved_answers")
def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool:
"""Validate that input contains required fields for all enrichments."""
return all(
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.context is not None
and enrichment.feedback_text is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
for enrichment in enrichments
)
def _render_reaction_prompt(
    question: str,
    context: str,
    wrong_answer: str,
    negative_feedback: str,
    prompt_location: str = "feedback_reaction_prompt.txt",
) -> str:
    """Render the feedback reaction prompt with the provided variables.

    Args:
        question: The original user question.
        context: Context the original answer was produced from.
        wrong_answer: The original answer that received feedback.
        negative_feedback: The feedback text for that answer.
        prompt_location: Prompt file name handed to ``read_query_prompt``.
            Defaults to ``"feedback_reaction_prompt.txt"``.

    Returns:
        The template with its ``{question}``, ``{context}``, ``{wrong_answer}``
        and ``{negative_feedback}`` placeholders filled via ``str.format``.

    Raises:
        ValueError: If no template text was loaded for ``prompt_location``.
    """
    prompt_template = read_query_prompt(prompt_location)
    # NOTE(review): read_query_prompt appears able to return None when the file
    # cannot be loaded (not verifiable from this module); fail with a clear
    # message rather than an opaque AttributeError on ``None.format``.
    if prompt_template is None:
        raise ValueError(f"Could not load reaction prompt template: {prompt_location!r}")
    return prompt_template.format(
        question=question,
        context=context,
        wrong_answer=wrong_answer,
        negative_feedback=negative_feedback,
    )


async def _generate_improved_answer_for_single_interaction(
    enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str
) -> Optional[FeedbackEnrichment]:
    """Generate an improved answer for one enrichment via structured retriever completion.

    Renders the reaction prompt stored at ``reaction_prompt_location`` with the
    enrichment's question, context, original answer and feedback. It then
    retrieves graph context for that prompt and asks the retriever for an
    ``ImprovedAnswerResponse``. On success the enrichment is updated in place
    (``improved_answer``, ``new_context``, ``explanation``) and returned.

    Args:
        enrichment: The enrichment to improve; mutated in place on success.
        retriever: Object exposing async ``get_context``, ``get_completion`` and
            ``resolve_edges_to_text`` (a ``GraphCompletionCotRetriever`` in this
            module).
        reaction_prompt_location: Prompt file name used to render the query.

    Returns:
        The updated enrichment, or ``None`` if no completion was produced or
        any step raised. Failures are logged, not propagated.
    """
    try:
        query_text = _render_reaction_prompt(
            enrichment.question,
            enrichment.context,
            enrichment.original_answer,
            enrichment.feedback_text,
            prompt_location=reaction_prompt_location,
        )
        retrieved_context = await retriever.get_context(query_text)
        completion = await retriever.get_completion(
            query=query_text,
            context=retrieved_context,
            response_model=ImprovedAnswerResponse,
            max_iter=4,
        )
        if not completion:
            logger.warning(
                "Failed to get structured completion from retriever", question=enrichment.question
            )
            return None

        # Only render the edges to text once we know the result will be stored.
        new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
        best = completion[0]
        enrichment.improved_answer = best.answer
        enrichment.new_context = new_context_text
        enrichment.explanation = best.explanation
        return enrichment
    except Exception as exc:  # noqa: BLE001
        # Best-effort per item: one failing enrichment must not abort the batch.
        logger.error(
            "Failed to generate improved answer",
            error=str(exc),
            question=enrichment.question,
        )
        return None
async def generate_improved_answers(
    enrichments: List[FeedbackEnrichment],
    top_k: int = 20,
    reaction_prompt_location: str = "feedback_reaction_prompt.txt",
) -> List[FeedbackEnrichment]:
    """Produce improved answers for feedback enrichments using a CoT graph retriever.

    Args:
        enrichments: Enrichments to improve. Each must pass ``_validate_input_data``.
        top_k: Forwarded to ``GraphCompletionCotRetriever``.
        reaction_prompt_location: Forwarded to the per-enrichment helper.

    Returns:
        The enrichments that received an improved answer, in input order. An
        empty list is returned when the input is empty or when validation fails.
        Validation is all-or-nothing, so one incomplete enrichment rejects the
        whole batch.
    """
    if not enrichments:
        logger.info("No enrichments provided; returning empty list")
        return []
    if not _validate_input_data(enrichments):
        logger.error("Input data validation failed; missing required fields")
        return []

    # A single retriever instance is shared across all enrichments in the batch.
    cot_retriever = GraphCompletionCotRetriever(
        top_k=top_k,
        save_interaction=False,
        user_prompt_path="graph_context_for_question.txt",
        system_prompt_path="answer_simple_question.txt",
    )

    successes: List[FeedbackEnrichment] = []
    for item in enrichments:
        # Sequential: each enrichment is awaited before the next one starts.
        outcome = await _generate_improved_answer_for_single_interaction(
            item, cot_retriever, reaction_prompt_location
        )
        if not outcome:
            logger.warning(
                "Failed to generate improved answer",
                question=item.question,
                interaction_id=item.interaction_id,
            )
            continue
        successes.append(outcome)

    logger.info("Generated improved answers", count=len(successes))
    return successes