From 834cf8b11307f38a09b42660db493bdf2ddaa14c Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Tue, 21 Oct 2025 00:34:12 +0200 Subject: [PATCH] feat: create_enrichments.py --- cognee/tasks/feedback/create_enrichments.py | 145 ++++++++++++++++++ .../feedback/generate_improved_answers.py | 5 +- cognee/tasks/feedback/models.py | 3 +- .../feedback_enrichment_minimal_example.py | 12 +- 4 files changed, 158 insertions(+), 7 deletions(-) create mode 100644 cognee/tasks/feedback/create_enrichments.py diff --git a/cognee/tasks/feedback/create_enrichments.py b/cognee/tasks/feedback/create_enrichments.py new file mode 100644 index 000000000..99de162b4 --- /dev/null +++ b/cognee/tasks/feedback/create_enrichments.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from typing import Dict, List, Optional +from uuid import NAMESPACE_OID, uuid5 + +from cognee.infrastructure.llm import LLMGateway +from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt +from cognee.shared.logging_utils import get_logger +from cognee.modules.engine.models import NodeSet + +from .models import FeedbackEnrichment + + +logger = get_logger("create_enrichments") + + +def _validate_improved_answers(improved_answers: List[Dict]) -> bool: + """Validate that all items contain required fields for enrichment creation.""" + required_fields = [ + "question", + "answer", # This is the original answer field from feedback_interaction + "improved_answer", + "new_context", + "feedback_id", + "interaction_id", + ] + return all( + all(item.get(field) is not None for field in required_fields) for item in improved_answers + ) + + +def _validate_uuid_fields(improved_answers: List[Dict]) -> bool: + """Validate that feedback_id and interaction_id are valid UUID objects.""" + try: + for item in improved_answers: + feedback_id = item.get("feedback_id") + interaction_id = item.get("interaction_id") + if not isinstance(feedback_id, type(feedback_id)) or not isinstance( + interaction_id, type(interaction_id) + ): + return False + return True + except Exception: + return False + + +async def _generate_enrichment_report( + question: str, improved_answer: str, new_context: str, report_prompt_location: str +) -> str: + """Generate educational report using feedback report prompt.""" + try: + prompt_template = read_query_prompt(report_prompt_location) + rendered_prompt = prompt_template.format( + question=question, + improved_answer=improved_answer, + new_context=new_context, + ) + return await LLMGateway.acreate_structured_output( + text_input=rendered_prompt, + system_prompt="You are a helpful assistant that creates educational content.", + response_model=str, + ) + except Exception as exc: + logger.warning("Failed to generate enrichment report", error=str(exc), question=question) + return f"Educational content for: {question} - {improved_answer}" + + +async def _create_enrichment_datapoint( + improved_answer_item: Dict, + report_text: str, +) -> Optional[FeedbackEnrichment]: + """Create a single FeedbackEnrichment DataPoint with proper ID and nodeset assignment.""" + try: + question = improved_answer_item["question"] + improved_answer = improved_answer_item["improved_answer"] + + # Create nodeset following UserQAFeedback pattern + nodeset = NodeSet( + id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment" + ) + + enrichment = FeedbackEnrichment( + id=str(uuid5(NAMESPACE_OID, f"{question}_{improved_answer}")), + text=report_text, + question=question, + original_answer=improved_answer_item["answer"], # Use "answer" field + improved_answer=improved_answer, + feedback_id=improved_answer_item["feedback_id"], + interaction_id=improved_answer_item["interaction_id"], + belongs_to_set=nodeset, + ) + + return enrichment + except Exception as exc: + logger.error( + "Failed to create enrichment datapoint", + error=str(exc), + question=improved_answer_item.get("question"), + ) + return None + + +async def create_enrichments( + improved_answers: List[Dict], + report_prompt_location: str = "feedback_report_prompt.txt", +) -> List[FeedbackEnrichment]: + """Create FeedbackEnrichment DataPoint instances from improved answers.""" + if not improved_answers: + logger.info("No improved answers provided; returning empty list") + return [] + + if not _validate_improved_answers(improved_answers): + logger.error("Input validation failed; missing required fields") + return [] + + if not _validate_uuid_fields(improved_answers): + logger.error("UUID validation failed; invalid feedback_id or interaction_id") + return [] + + logger.info("Creating enrichments", count=len(improved_answers)) + + enrichments: List[FeedbackEnrichment] = [] + + for improved_answer_item in improved_answers: + question = improved_answer_item["question"] + improved_answer = improved_answer_item["improved_answer"] + new_context = improved_answer_item["new_context"] + + report_text = await _generate_enrichment_report( + question, improved_answer, new_context, report_prompt_location + ) + + enrichment = await _create_enrichment_datapoint(improved_answer_item, report_text) + + if enrichment: + enrichments.append(enrichment) + else: + logger.warning( + "Failed to create enrichment", + question=question, + interaction_id=improved_answer_item.get("interaction_id"), + ) + + logger.info("Created enrichments", successful=len(enrichments)) + return enrichments diff --git a/cognee/tasks/feedback/generate_improved_answers.py b/cognee/tasks/feedback/generate_improved_answers.py index de8b1cc0b..a4edab7c1 100644 --- a/cognee/tasks/feedback/generate_improved_answers.py +++ b/cognee/tasks/feedback/generate_improved_answers.py @@ -66,7 +66,10 @@ async def _generate_improved_answer_for_single_interaction( retrieved_context = await retriever.get_context(query_text) completion = await retriever.get_structured_completion( - query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse + query=query_text, + context=retrieved_context, + response_model=ImprovedAnswerResponse, + max_iter=1, ) new_context_text = await retriever.resolve_edges_to_text(retrieved_context) diff --git a/cognee/tasks/feedback/models.py b/cognee/tasks/feedback/models.py index 403bc0e13..ae1064709 100644 --- a/cognee/tasks/feedback/models.py +++ b/cognee/tasks/feedback/models.py @@ -2,7 +2,7 @@ from typing import List, Optional, Union from uuid import UUID from cognee.infrastructure.engine import DataPoint -from cognee.modules.engine.models import Entity +from cognee.modules.engine.models import Entity, NodeSet from cognee.tasks.temporal_graph.models import Event @@ -18,3 +18,4 @@ class FeedbackEnrichment(DataPoint): improved_answer: str feedback_id: UUID interaction_id: UUID + belongs_to_set: Optional[NodeSet] = None diff --git a/examples/python/feedback_enrichment_minimal_example.py b/examples/python/feedback_enrichment_minimal_example.py index a36a7af8a..9fbb84821 100644 --- a/examples/python/feedback_enrichment_minimal_example.py +++ b/examples/python/feedback_enrichment_minimal_example.py @@ -6,6 +6,7 @@ from cognee.modules.pipelines.tasks.task import Task from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers +from cognee.tasks.feedback.create_enrichments import create_enrichments CONVERSATION = [ @@ -48,11 +49,12 @@ async def run_question_and_submit_feedback(question_text: str) -> bool: async def run_feedback_enrichment_memify(last_n: int = 5): - """Execute memify with extraction and answer improvement tasks.""" + """Execute memify with extraction, answer improvement, and enrichment creation tasks.""" # Instantiate tasks with their own kwargs extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)] enrichment_tasks = [ - Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20) + Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20), + Task(create_enrichments), ] await cognee.memify( extraction_tasks=extraction_tasks, @@ -63,9 +65,9 @@ async def run_feedback_enrichment_memify(last_n: int = 5): async def main(): - # await initialize_conversation_and_graph(CONVERSATION) - # is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?") - is_correct = False + await initialize_conversation_and_graph(CONVERSATION) + is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?") + # is_correct = False if not is_correct: await run_feedback_enrichment_memify(last_n=5)