feat: create_enrichments.py
This commit is contained in:
parent
ce418828b4
commit
834cf8b113
4 changed files with 158 additions and 7 deletions
145
cognee/tasks/feedback/create_enrichments.py
Normal file
145
cognee/tasks/feedback/create_enrichments.py
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.engine.models import NodeSet
|
||||||
|
|
||||||
|
from .models import FeedbackEnrichment
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("create_enrichments")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_improved_answers(improved_answers: List[Dict]) -> bool:
|
||||||
|
"""Validate that all items contain required fields for enrichment creation."""
|
||||||
|
required_fields = [
|
||||||
|
"question",
|
||||||
|
"answer", # This is the original answer field from feedback_interaction
|
||||||
|
"improved_answer",
|
||||||
|
"new_context",
|
||||||
|
"feedback_id",
|
||||||
|
"interaction_id",
|
||||||
|
]
|
||||||
|
return all(
|
||||||
|
all(item.get(field) is not None for field in required_fields) for item in improved_answers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_uuid_fields(improved_answers: List[Dict]) -> bool:
|
||||||
|
"""Validate that feedback_id and interaction_id are valid UUID objects."""
|
||||||
|
try:
|
||||||
|
for item in improved_answers:
|
||||||
|
feedback_id = item.get("feedback_id")
|
||||||
|
interaction_id = item.get("interaction_id")
|
||||||
|
if not isinstance(feedback_id, type(feedback_id)) or not isinstance(
|
||||||
|
interaction_id, type(interaction_id)
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_enrichment_report(
|
||||||
|
question: str, improved_answer: str, new_context: str, report_prompt_location: str
|
||||||
|
) -> str:
|
||||||
|
"""Generate educational report using feedback report prompt."""
|
||||||
|
try:
|
||||||
|
prompt_template = read_query_prompt(report_prompt_location)
|
||||||
|
rendered_prompt = prompt_template.format(
|
||||||
|
question=question,
|
||||||
|
improved_answer=improved_answer,
|
||||||
|
new_context=new_context,
|
||||||
|
)
|
||||||
|
return await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=rendered_prompt,
|
||||||
|
system_prompt="You are a helpful assistant that creates educational content.",
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to generate enrichment report", error=str(exc), question=question)
|
||||||
|
return f"Educational content for: {question} - {improved_answer}"
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_enrichment_datapoint(
|
||||||
|
improved_answer_item: Dict,
|
||||||
|
report_text: str,
|
||||||
|
) -> Optional[FeedbackEnrichment]:
|
||||||
|
"""Create a single FeedbackEnrichment DataPoint with proper ID and nodeset assignment."""
|
||||||
|
try:
|
||||||
|
question = improved_answer_item["question"]
|
||||||
|
improved_answer = improved_answer_item["improved_answer"]
|
||||||
|
|
||||||
|
# Create nodeset following UserQAFeedback pattern
|
||||||
|
nodeset = NodeSet(
|
||||||
|
id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment"
|
||||||
|
)
|
||||||
|
|
||||||
|
enrichment = FeedbackEnrichment(
|
||||||
|
id=str(uuid5(NAMESPACE_OID, f"{question}_{improved_answer}")),
|
||||||
|
text=report_text,
|
||||||
|
question=question,
|
||||||
|
original_answer=improved_answer_item["answer"], # Use "answer" field
|
||||||
|
improved_answer=improved_answer,
|
||||||
|
feedback_id=improved_answer_item["feedback_id"],
|
||||||
|
interaction_id=improved_answer_item["interaction_id"],
|
||||||
|
belongs_to_set=nodeset,
|
||||||
|
)
|
||||||
|
|
||||||
|
return enrichment
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"Failed to create enrichment datapoint",
|
||||||
|
error=str(exc),
|
||||||
|
question=improved_answer_item.get("question"),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def create_enrichments(
|
||||||
|
improved_answers: List[Dict],
|
||||||
|
report_prompt_location: str = "feedback_report_prompt.txt",
|
||||||
|
) -> List[FeedbackEnrichment]:
|
||||||
|
"""Create FeedbackEnrichment DataPoint instances from improved answers."""
|
||||||
|
if not improved_answers:
|
||||||
|
logger.info("No improved answers provided; returning empty list")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not _validate_improved_answers(improved_answers):
|
||||||
|
logger.error("Input validation failed; missing required fields")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not _validate_uuid_fields(improved_answers):
|
||||||
|
logger.error("UUID validation failed; invalid feedback_id or interaction_id")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info("Creating enrichments", count=len(improved_answers))
|
||||||
|
|
||||||
|
enrichments: List[FeedbackEnrichment] = []
|
||||||
|
|
||||||
|
for improved_answer_item in improved_answers:
|
||||||
|
question = improved_answer_item["question"]
|
||||||
|
improved_answer = improved_answer_item["improved_answer"]
|
||||||
|
new_context = improved_answer_item["new_context"]
|
||||||
|
|
||||||
|
report_text = await _generate_enrichment_report(
|
||||||
|
question, improved_answer, new_context, report_prompt_location
|
||||||
|
)
|
||||||
|
|
||||||
|
enrichment = await _create_enrichment_datapoint(improved_answer_item, report_text)
|
||||||
|
|
||||||
|
if enrichment:
|
||||||
|
enrichments.append(enrichment)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to create enrichment",
|
||||||
|
question=question,
|
||||||
|
interaction_id=improved_answer_item.get("interaction_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Created enrichments", successful=len(enrichments))
|
||||||
|
return enrichments
|
||||||
|
|
@ -66,7 +66,10 @@ async def _generate_improved_answer_for_single_interaction(
|
||||||
|
|
||||||
retrieved_context = await retriever.get_context(query_text)
|
retrieved_context = await retriever.get_context(query_text)
|
||||||
completion = await retriever.get_structured_completion(
|
completion = await retriever.get_structured_completion(
|
||||||
query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse
|
query=query_text,
|
||||||
|
context=retrieved_context,
|
||||||
|
response_model=ImprovedAnswerResponse,
|
||||||
|
max_iter=1,
|
||||||
)
|
)
|
||||||
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
|
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from typing import List, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.engine.models import Entity
|
from cognee.modules.engine.models import Entity, NodeSet
|
||||||
from cognee.tasks.temporal_graph.models import Event
|
from cognee.tasks.temporal_graph.models import Event
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,3 +18,4 @@ class FeedbackEnrichment(DataPoint):
|
||||||
improved_answer: str
|
improved_answer: str
|
||||||
feedback_id: UUID
|
feedback_id: UUID
|
||||||
interaction_id: UUID
|
interaction_id: UUID
|
||||||
|
belongs_to_set: Optional[NodeSet] = None
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from cognee.modules.pipelines.tasks.task import Task
|
||||||
|
|
||||||
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
|
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
|
||||||
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
|
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
|
||||||
|
from cognee.tasks.feedback.create_enrichments import create_enrichments
|
||||||
|
|
||||||
|
|
||||||
CONVERSATION = [
|
CONVERSATION = [
|
||||||
|
|
@ -48,11 +49,12 @@ async def run_question_and_submit_feedback(question_text: str) -> bool:
|
||||||
|
|
||||||
|
|
||||||
async def run_feedback_enrichment_memify(last_n: int = 5):
|
async def run_feedback_enrichment_memify(last_n: int = 5):
|
||||||
"""Execute memify with extraction and answer improvement tasks."""
|
"""Execute memify with extraction, answer improvement, and enrichment creation tasks."""
|
||||||
# Instantiate tasks with their own kwargs
|
# Instantiate tasks with their own kwargs
|
||||||
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
|
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
|
||||||
enrichment_tasks = [
|
enrichment_tasks = [
|
||||||
Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20)
|
Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20),
|
||||||
|
Task(create_enrichments),
|
||||||
]
|
]
|
||||||
await cognee.memify(
|
await cognee.memify(
|
||||||
extraction_tasks=extraction_tasks,
|
extraction_tasks=extraction_tasks,
|
||||||
|
|
@ -63,9 +65,9 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# await initialize_conversation_and_graph(CONVERSATION)
|
await initialize_conversation_and_graph(CONVERSATION)
|
||||||
# is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
|
is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
|
||||||
is_correct = False
|
# is_correct = False
|
||||||
if not is_correct:
|
if not is_correct:
|
||||||
await run_feedback_enrichment_memify(last_n=5)
|
await run_feedback_enrichment_memify(last_n=5)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue