feat: create_enrichments.py

This commit is contained in:
lxobr 2025-10-21 00:34:12 +02:00
parent ce418828b4
commit 834cf8b113
4 changed files with 158 additions and 7 deletions

View file

@ -0,0 +1,145 @@
from __future__ import annotations
from typing import Dict, List, Optional
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.models import NodeSet
from .models import FeedbackEnrichment
logger = get_logger("create_enrichments")
def _validate_improved_answers(improved_answers: List[Dict]) -> bool:
    """Check that every item has a non-None value for each required key.

    Args:
        improved_answers: Items to inspect; each is looked up with ``.get``.

    Returns:
        True when no item is missing a required key or holds ``None`` for it
        (an empty list passes trivially), otherwise False.
    """
    mandatory_keys = (
        "question",
        "answer",  # This is the original answer field from feedback_interaction
        "improved_answer",
        "new_context",
        "feedback_id",
        "interaction_id",
    )
    for entry in improved_answers:
        for key in mandatory_keys:
            # Only None counts as missing; empty strings and other falsy values pass.
            if entry.get(key) is None:
                return False
    return True
def _validate_uuid_fields(improved_answers: List[Dict]) -> bool:
    """Validate that feedback_id and interaction_id of every item are UUIDs.

    A value is accepted when it is a ``uuid.UUID`` instance or a string that
    ``uuid.UUID`` can parse. Anything else (``None``, numbers, malformed
    strings) fails validation.

    Args:
        improved_answers: Items whose ``feedback_id`` and ``interaction_id``
            entries are checked.

    Returns:
        True when both fields of every item are acceptable (an empty list
        passes trivially), otherwise False. Never raises: any error while
        inspecting the items is reported as False.
    """
    # Imported locally so this function is self-contained; the module header
    # only brings in NAMESPACE_OID and uuid5 from the uuid module.
    from uuid import UUID

    def _is_uuid(value: object) -> bool:
        """Return True for a UUID instance or a string in valid UUID format."""
        if isinstance(value, UUID):
            return True
        if isinstance(value, str):
            try:
                UUID(value)
            except ValueError:
                return False
            return True
        return False

    try:
        return all(
            _is_uuid(item.get("feedback_id")) and _is_uuid(item.get("interaction_id"))
            for item in improved_answers
        )
    except Exception:
        # Keep the original never-raise contract, e.g. for items without ``.get``.
        return False
async def _generate_enrichment_report(
    question: str, improved_answer: str, new_context: str, report_prompt_location: str
) -> str:
    """Produce the report text for one improved answer via the LLM gateway.

    The prompt template read from ``report_prompt_location`` is filled with the
    question, improved answer and new context, then sent to
    ``LLMGateway.acreate_structured_output`` with ``str`` as the response model.

    Returns:
        The gateway's response, or a fallback string built from the question
        and improved answer when any step raises.
    """
    try:
        template_fields = {
            "question": question,
            "improved_answer": improved_answer,
            "new_context": new_context,
        }
        llm_input = read_query_prompt(report_prompt_location).format(**template_fields)
        report = await LLMGateway.acreate_structured_output(
            text_input=llm_input,
            system_prompt="You are a helpful assistant that creates educational content.",
            response_model=str,
        )
        return report
    except Exception as exc:
        # Best-effort: a failed report must not abort enrichment creation.
        logger.warning("Failed to generate enrichment report", error=str(exc), question=question)
        return f"Educational content for: {question} - {improved_answer}"
async def _create_enrichment_datapoint(
    improved_answer_item: Dict,
    report_text: str,
) -> Optional[FeedbackEnrichment]:
    """Build one FeedbackEnrichment from an improved-answer item and its report.

    The enrichment id is a uuid5 of ``"{question}_{improved_answer}"`` and the
    datapoint is attached to a NodeSet named "FeedbackEnrichment".

    Returns:
        The new FeedbackEnrichment, or None when construction raises (the
        failure is logged).
    """
    try:
        item_question = improved_answer_item["question"]
        item_improved = improved_answer_item["improved_answer"]

        # Create nodeset following UserQAFeedback pattern
        set_label = "FeedbackEnrichment"
        enrichment_set = NodeSet(id=uuid5(NAMESPACE_OID, name=set_label), name=set_label)

        enrichment_id = str(uuid5(NAMESPACE_OID, f"{item_question}_{item_improved}"))
        return FeedbackEnrichment(
            id=enrichment_id,
            text=report_text,
            question=item_question,
            original_answer=improved_answer_item["answer"],  # Use "answer" field
            improved_answer=item_improved,
            feedback_id=improved_answer_item["feedback_id"],
            interaction_id=improved_answer_item["interaction_id"],
            belongs_to_set=enrichment_set,
        )
    except Exception as exc:
        logger.error(
            "Failed to create enrichment datapoint",
            error=str(exc),
            question=improved_answer_item.get("question"),
        )
        return None
async def create_enrichments(
    improved_answers: List[Dict],
    report_prompt_location: str = "feedback_report_prompt.txt",
) -> List[FeedbackEnrichment]:
    """Turn improved answers into FeedbackEnrichment DataPoint instances.

    Args:
        improved_answers: Items that must pass both the required-field and
            the UUID validation before any enrichment is built.
        report_prompt_location: Prompt template location passed on to report
            generation.

    Returns:
        The enrichments that were created. Empty when the input is empty or
        fails validation; items whose datapoint could not be built are skipped.
    """
    if not improved_answers:
        logger.info("No improved answers provided; returning empty list")
        return []

    # Validation gates, applied in order; the first failure aborts the batch.
    validation_gates = (
        (_validate_improved_answers, "Input validation failed; missing required fields"),
        (_validate_uuid_fields, "UUID validation failed; invalid feedback_id or interaction_id"),
    )
    for gate, failure_message in validation_gates:
        if not gate(improved_answers):
            logger.error(failure_message)
            return []

    logger.info("Creating enrichments", count=len(improved_answers))

    created: List[FeedbackEnrichment] = []
    for item in improved_answers:
        item_question = item["question"]
        report = await _generate_enrichment_report(
            item_question,
            item["improved_answer"],
            item["new_context"],
            report_prompt_location,
        )
        datapoint = await _create_enrichment_datapoint(item, report)
        if not datapoint:
            logger.warning(
                "Failed to create enrichment",
                question=item_question,
                interaction_id=item.get("interaction_id"),
            )
            continue
        created.append(datapoint)

    logger.info("Created enrichments", successful=len(created))
    return created

View file

@ -66,7 +66,10 @@ async def _generate_improved_answer_for_single_interaction(
retrieved_context = await retriever.get_context(query_text) retrieved_context = await retriever.get_context(query_text)
completion = await retriever.get_structured_completion( completion = await retriever.get_structured_completion(
query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse query=query_text,
context=retrieved_context,
response_model=ImprovedAnswerResponse,
max_iter=1,
) )
new_context_text = await retriever.resolve_edges_to_text(retrieved_context) new_context_text = await retriever.resolve_edges_to_text(retrieved_context)

View file

@ -2,7 +2,7 @@ from typing import List, Optional, Union
from uuid import UUID from uuid import UUID
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import Entity from cognee.modules.engine.models import Entity, NodeSet
from cognee.tasks.temporal_graph.models import Event from cognee.tasks.temporal_graph.models import Event
@ -18,3 +18,4 @@ class FeedbackEnrichment(DataPoint):
improved_answer: str improved_answer: str
feedback_id: UUID feedback_id: UUID
interaction_id: UUID interaction_id: UUID
belongs_to_set: Optional[NodeSet] = None

View file

@ -6,6 +6,7 @@ from cognee.modules.pipelines.tasks.task import Task
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
from cognee.tasks.feedback.create_enrichments import create_enrichments
CONVERSATION = [ CONVERSATION = [
@ -48,11 +49,12 @@ async def run_question_and_submit_feedback(question_text: str) -> bool:
async def run_feedback_enrichment_memify(last_n: int = 5): async def run_feedback_enrichment_memify(last_n: int = 5):
"""Execute memify with extraction and answer improvement tasks.""" """Execute memify with extraction, answer improvement, and enrichment creation tasks."""
# Instantiate tasks with their own kwargs # Instantiate tasks with their own kwargs
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)] extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
enrichment_tasks = [ enrichment_tasks = [
Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20) Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20),
Task(create_enrichments),
] ]
await cognee.memify( await cognee.memify(
extraction_tasks=extraction_tasks, extraction_tasks=extraction_tasks,
@ -63,9 +65,9 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
async def main(): async def main():
# await initialize_conversation_and_graph(CONVERSATION) await initialize_conversation_and_graph(CONVERSATION)
# is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?") is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
is_correct = False # is_correct = False
if not is_correct: if not is_correct:
await run_feedback_enrichment_memify(last_n=5) await run_feedback_enrichment_memify(last_n=5)