diff --git a/cognee/tasks/feedback/__init__.py b/cognee/tasks/feedback/__init__.py new file mode 100644 index 000000000..25102dfb4 --- /dev/null +++ b/cognee/tasks/feedback/__init__.py @@ -0,0 +1,13 @@ +from .extract_feedback_interactions import extract_feedback_interactions +from .generate_improved_answers import generate_improved_answers +from .create_enrichments import create_enrichments +from .link_enrichments_to_feedback import link_enrichments_to_feedback +from .models import FeedbackEnrichment + +__all__ = [ + "extract_feedback_interactions", + "generate_improved_answers", + "create_enrichments", + "link_enrichments_to_feedback", + "FeedbackEnrichment", +] diff --git a/cognee/tasks/feedback/create_enrichments.py b/cognee/tasks/feedback/create_enrichments.py index 9ffbf9f88..ee15e9797 100644 --- a/cognee/tasks/feedback/create_enrichments.py +++ b/cognee/tasks/feedback/create_enrichments.py @@ -14,36 +14,19 @@ from .models import FeedbackEnrichment logger = get_logger("create_enrichments") -def _validate_improved_answers(improved_answers: List[Dict]) -> bool: - """Validate that all items contain required fields for enrichment creation.""" - required_fields = [ - "question", - "answer", # This is the original answer field from feedback_interaction - "improved_answer", - "new_context", - "feedback_id", - "interaction_id", - ] +def _validate_enrichments(enrichments: List[FeedbackEnrichment]) -> bool: + """Validate that all enrichments contain required fields for completion.""" return all( - all(item.get(field) is not None for field in required_fields) for item in improved_answers + enrichment.question is not None + and enrichment.original_answer is not None + and enrichment.improved_answer is not None + and enrichment.new_context is not None + and enrichment.feedback_id is not None + and enrichment.interaction_id is not None + for enrichment in enrichments ) -def _validate_uuid_fields(improved_answers: List[Dict]) -> bool: - """Validate that feedback_id and interaction_id are valid UUID objects.""" - try: - for item in improved_answers: - feedback_id = item.get("feedback_id") - interaction_id = item.get("interaction_id") - if not isinstance(feedback_id, type(feedback_id)) or not isinstance( - interaction_id, type(interaction_id) - ): - return False - return True - except Exception: - return False - - async def _generate_enrichment_report( question: str, improved_answer: str, new_context: str, report_prompt_location: str ) -> str: @@ -65,80 +48,37 @@ async def _generate_enrichment_report( return f"Educational content for: {question} - {improved_answer}" -async def _create_enrichment_datapoint( - improved_answer_item: Dict, - report_text: str, - nodeset: NodeSet, -) -> Optional[FeedbackEnrichment]: - """Create a single FeedbackEnrichment DataPoint with proper ID and nodeset assignment.""" - try: - question = improved_answer_item["question"] - improved_answer = improved_answer_item["improved_answer"] - - enrichment = FeedbackEnrichment( - id=str(uuid5(NAMESPACE_OID, f"{question}_{improved_answer}")), - text=report_text, - question=question, - original_answer=improved_answer_item["answer"], # Use "answer" field - improved_answer=improved_answer, - feedback_id=improved_answer_item["feedback_id"], - interaction_id=improved_answer_item["interaction_id"], - belongs_to_set=[nodeset], - ) - - return enrichment - except Exception as exc: - logger.error( - "Failed to create enrichment datapoint", - error=str(exc), - question=improved_answer_item.get("question"), - ) - return None - - async def create_enrichments( - improved_answers: List[Dict], + enrichments: List[FeedbackEnrichment], report_prompt_location: str = "feedback_report_prompt.txt", ) -> List[FeedbackEnrichment]: - """Create FeedbackEnrichment DataPoint instances from improved answers.""" - if not improved_answers: - logger.info("No improved answers provided; returning empty list") + """Fill text and belongs_to_set fields of existing FeedbackEnrichment DataPoints.""" + if not enrichments: + logger.info("No enrichments provided; returning empty list") return [] - if not _validate_improved_answers(improved_answers): + if not _validate_enrichments(enrichments): logger.error("Input validation failed; missing required fields") return [] - if not _validate_uuid_fields(improved_answers): - logger.error("UUID validation failed; invalid feedback_id or interaction_id") - return [] + logger.info("Completing enrichments", count=len(enrichments)) - logger.info("Creating enrichments", count=len(improved_answers)) - - # Create nodeset once for all enrichments nodeset = NodeSet(id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment") - enrichments: List[FeedbackEnrichment] = [] - - for improved_answer_item in improved_answers: - question = improved_answer_item["question"] - improved_answer = improved_answer_item["improved_answer"] - new_context = improved_answer_item["new_context"] + completed_enrichments: List[FeedbackEnrichment] = [] + for enrichment in enrichments: report_text = await _generate_enrichment_report( - question, improved_answer, new_context, report_prompt_location + enrichment.question, + enrichment.improved_answer, + enrichment.new_context, + report_prompt_location, ) - enrichment = await _create_enrichment_datapoint(improved_answer_item, report_text, nodeset) + enrichment.text = report_text + enrichment.belongs_to_set = [nodeset] - if enrichment: - enrichments.append(enrichment) - else: - logger.warning( - "Failed to create enrichment", - question=question, - interaction_id=improved_answer_item.get("interaction_id"), - ) + completed_enrichments.append(enrichment) - logger.info("Created enrichments", successful=len(enrichments)) - return enrichments + logger.info("Completed enrichments", successful=len(completed_enrichments)) + return completed_enrichments diff --git a/cognee/tasks/feedback/extract_feedback_interactions.py b/cognee/tasks/feedback/extract_feedback_interactions.py index 44f139d70..e5d03026e 100644 --- a/cognee/tasks/feedback/extract_feedback_interactions.py +++ b/cognee/tasks/feedback/extract_feedback_interactions.py @@ -7,8 +7,10 @@ from cognee.infrastructure.llm import LLMGateway from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.graph import get_graph_engine +from uuid import uuid5, NAMESPACE_OID from .utils import filter_negative_feedback +from .models import FeedbackEnrichment logger = get_logger("extract_feedback_interactions") @@ -49,11 +51,8 @@ def _match_feedback_nodes_to_interactions_by_edges( feedback_nodes: List, interaction_nodes: List, graph_edges: List ) -> List[Tuple[Tuple, Tuple]]: """Match feedback to interactions using gives_feedback_to edges.""" - # Build single lookup maps using normalized Cognee IDs interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes} feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes} - - # Filter to only gives_feedback_to edges feedback_edges = [ (source_id, target_id) for source_id, target_id, rel, _ in graph_edges @@ -103,23 +102,22 @@ async def _generate_human_readable_context_summary( return raw_context_text or "" -def _has_required_feedback_fields(record: Dict) -> bool: - """Validate required fields exist in the item dict.""" - required_fields = [ - "question", - "answer", - "context", - "feedback_text", - "feedback_id", - "interaction_id", - ] - return all(record.get(field_name) is not None for field_name in required_fields) +def _has_required_feedback_fields(enrichment: FeedbackEnrichment) -> bool: + """Validate required fields exist in the FeedbackEnrichment DataPoint.""" + return ( + enrichment.question is not None + and enrichment.original_answer is not None + and enrichment.context is not None + and enrichment.feedback_text is not None + and enrichment.feedback_id is not None + and enrichment.interaction_id is not None + ) async def _build_feedback_interaction_record( feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict -) -> Optional[Dict]: - """Build a single feedback-interaction record with context summary.""" +) -> Optional[FeedbackEnrichment]: + """Build a single FeedbackEnrichment DataPoint with context summary.""" try: question_text = interaction_props.get("question") original_answer_text = interaction_props.get("answer") @@ -130,17 +128,23 @@ async def _build_feedback_interaction_record( question_text or "", raw_context_text ) - feedback_interaction_record = { - "question": question_text, - "answer": original_answer_text, - "context": context_summary_text, - "feedback_text": feedback_text, - "feedback_id": UUID(str(feedback_node_id)), - "interaction_id": UUID(str(interaction_node_id)), - } + enrichment = FeedbackEnrichment( + id=str(uuid5(NAMESPACE_OID, f"{question_text}_{interaction_node_id}")), + text="", + question=question_text, + original_answer=original_answer_text, + improved_answer="", + feedback_id=UUID(str(feedback_node_id)), + interaction_id=UUID(str(interaction_node_id)), + belongs_to_set=None, + context=context_summary_text, + feedback_text=feedback_text, + new_context="", + explanation="", + ) - if _has_required_feedback_fields(feedback_interaction_record): - return feedback_interaction_record + if _has_required_feedback_fields(enrichment): + return enrichment else: logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id)) return None @@ -151,9 +155,9 @@ async def _build_feedback_interaction_record( async def _build_feedback_interaction_records( matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]], -) -> List[Dict]: - """Build all feedback-interaction records from matched pairs.""" - feedback_interaction_records: List[Dict] = [] +) -> List[FeedbackEnrichment]: + """Build all FeedbackEnrichment DataPoints from matched pairs.""" + feedback_interaction_records: List[FeedbackEnrichment] = [] for (feedback_node_id, feedback_props), ( interaction_node_id, interaction_props, @@ -168,8 +172,8 @@ async def _build_feedback_interaction_records( async def extract_feedback_interactions( subgraphs: List, last_n: Optional[int] = None -) -> List[Dict]: - """Extract negative feedback-interaction pairs; fetch internally and use last_n param for limiting.""" +) -> List[FeedbackEnrichment]: + """Extract negative feedback-interaction pairs and create FeedbackEnrichment DataPoints.""" graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data() if not graph_nodes: return [] diff --git a/cognee/tasks/feedback/generate_improved_answers.py b/cognee/tasks/feedback/generate_improved_answers.py index a4edab7c1..10059df7e 100644 --- a/cognee/tasks/feedback/generate_improved_answers.py +++ b/cognee/tasks/feedback/generate_improved_answers.py @@ -9,6 +9,7 @@ from cognee.modules.graph.utils import resolve_edges_to_text from cognee.shared.logging_utils import get_logger from .utils import create_retriever +from .models import FeedbackEnrichment class ImprovedAnswerResponse(BaseModel): @@ -21,19 +22,16 @@ class ImprovedAnswerResponse(BaseModel): logger = get_logger("generate_improved_answers") -def _validate_input_data(feedback_interactions: List[Dict]) -> bool: - """Validate that input contains required fields for all items.""" - required_fields = [ - "question", - "answer", - "context", - "feedback_text", - "feedback_id", - "interaction_id", - ] +def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool: + """Validate that input contains required fields for all enrichments.""" return all( - all(item.get(field) is not None for field in required_fields) - for item in feedback_interactions + enrichment.question is not None + and enrichment.original_answer is not None + and enrichment.context is not None + and enrichment.feedback_text is not None + and enrichment.feedback_id is not None + and enrichment.interaction_id is not None + for enrichment in enrichments ) @@ -51,17 +49,15 @@ def _render_reaction_prompt( async def _generate_improved_answer_for_single_interaction( - feedback_interaction: Dict, retriever, reaction_prompt_location: str -) -> Optional[Dict]: - """Generate improved answer for a single feedback-interaction pair using structured retriever completion.""" + enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str +) -> Optional[FeedbackEnrichment]: + """Generate improved answer for a single enrichment using structured retriever completion.""" try: - question_text = feedback_interaction["question"] - original_answer_text = feedback_interaction["answer"] - context_text = feedback_interaction["context"] - feedback_text = feedback_interaction["feedback_text"] - query_text = _render_reaction_prompt( - question_text, context_text, original_answer_text, feedback_text + enrichment.question, + enrichment.context, + enrichment.original_answer, + enrichment.feedback_text, ) retrieved_context = await retriever.get_context(query_text) @@ -69,20 +65,18 @@ async def _generate_improved_answer_for_single_interaction( query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse, - max_iter=1, + max_iter=4, ) new_context_text = await retriever.resolve_edges_to_text(retrieved_context) if completion: - return { - **feedback_interaction, - "improved_answer": completion.answer, - "new_context": new_context_text, - "explanation": completion.explanation, - } + enrichment.improved_answer = completion.answer + enrichment.new_context = new_context_text + enrichment.explanation = completion.explanation + return enrichment else: logger.warning( - "Failed to get structured completion from retriever", question=question_text + "Failed to get structured completion from retriever", question=enrichment.question ) return None @@ -90,23 +84,23 @@ async def _generate_improved_answer_for_single_interaction( logger.error( "Failed to generate improved answer", error=str(exc), - question=feedback_interaction.get("question"), + question=enrichment.question, ) return None async def generate_improved_answers( - feedback_interactions: List[Dict], + enrichments: List[FeedbackEnrichment], retriever_name: str = "graph_completion_cot", top_k: int = 20, reaction_prompt_location: str = "feedback_reaction_prompt.txt", -) -> List[Dict]: +) -> List[FeedbackEnrichment]: """Generate improved answers using configurable retriever and LLM.""" - if not feedback_interactions: - logger.info("No feedback interactions provided; returning empty list") + if not enrichments: + logger.info("No enrichments provided; returning empty list") return [] - if not _validate_input_data(feedback_interactions): + if not _validate_input_data(enrichments): logger.error("Input data validation failed; missing required fields") return [] @@ -117,11 +111,11 @@ async def generate_improved_answers( system_prompt_path="answer_simple_question.txt", ) - improved_answers: List[Dict] = [] + improved_answers: List[FeedbackEnrichment] = [] - for feedback_interaction in feedback_interactions: + for enrichment in enrichments: result = await _generate_improved_answer_for_single_interaction( - feedback_interaction, retriever, reaction_prompt_location + enrichment, retriever, reaction_prompt_location ) if result: @@ -129,8 +123,8 @@ async def generate_improved_answers( else: logger.warning( "Failed to generate improved answer", - question=feedback_interaction.get("question"), - interaction_id=feedback_interaction.get("interaction_id"), + question=enrichment.question, + interaction_id=enrichment.interaction_id, ) logger.info("Generated improved answers", count=len(improved_answers)) diff --git a/cognee/tasks/feedback/link_enrichments_to_feedback.py b/cognee/tasks/feedback/link_enrichments_to_feedback.py new file mode 100644 index 000000000..d536bdc56 --- /dev/null +++ b/cognee/tasks/feedback/link_enrichments_to_feedback.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import List, Tuple +from uuid import UUID + +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.tasks.storage import index_graph_edges +from cognee.shared.logging_utils import get_logger + +from .models import FeedbackEnrichment + + +logger = get_logger("link_enrichments_to_feedback") + + +def _create_edge_tuple( + source_id: UUID, target_id: UUID, relationship_name: str +) -> Tuple[UUID, UUID, str, dict]: + """Create an edge tuple with proper properties structure.""" + return ( + source_id, + target_id, + relationship_name, + { + "relationship_name": relationship_name, + "source_node_id": source_id, + "target_node_id": target_id, + "ontology_valid": False, + }, + ) + + +async def link_enrichments_to_feedback( + enrichments: List[FeedbackEnrichment], +) -> List[FeedbackEnrichment]: + """Manually create edges from enrichments to original feedback/interaction nodes.""" + if not enrichments: + logger.info("No enrichments provided; returning empty list") + return [] + + relationships = [] + + for enrichment in enrichments: + enrichment_id = enrichment.id + feedback_id = enrichment.feedback_id + interaction_id = enrichment.interaction_id + + if enrichment_id and feedback_id: + enriches_feedback_edge = _create_edge_tuple( + enrichment_id, feedback_id, "enriches_feedback" + ) + relationships.append(enriches_feedback_edge) + + if enrichment_id and interaction_id: + improves_interaction_edge = _create_edge_tuple( + enrichment_id, interaction_id, "improves_interaction" + ) + relationships.append(improves_interaction_edge) + + if relationships: + graph_engine = await get_graph_engine() + await graph_engine.add_edges(relationships) + await index_graph_edges(relationships) + logger.info("Linking enrichments to feedback", edge_count=len(relationships)) + + logger.info("Linked enrichments", enrichment_count=len(enrichments)) + return enrichments diff --git a/cognee/tasks/feedback/models.py b/cognee/tasks/feedback/models.py index 6815c2de1..c334ec8c0 100644 --- a/cognee/tasks/feedback/models.py +++ b/cognee/tasks/feedback/models.py @@ -19,3 +19,8 @@ class FeedbackEnrichment(DataPoint): feedback_id: UUID interaction_id: UUID belongs_to_set: Optional[List[NodeSet]] = None + + context: str = "" + feedback_text: str = "" + new_context: str = "" + explanation: str = "" diff --git a/examples/python/feedback_enrichment_minimal_example.py b/examples/python/feedback_enrichment_minimal_example.py index c37c0fbdf..8e7f01c7d 100644 --- a/examples/python/feedback_enrichment_minimal_example.py +++ b/examples/python/feedback_enrichment_minimal_example.py @@ -10,6 +10,7 @@ from cognee.shared.data_models import KnowledgeGraph from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers from cognee.tasks.feedback.create_enrichments import create_enrichments +from cognee.tasks.feedback.link_enrichments_to_feedback import link_enrichments_to_feedback CONVERSATION = [ @@ -60,6 +61,7 @@ async def run_feedback_enrichment_memify(last_n: int = 5): Task(create_enrichments), Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}), Task(add_data_points, task_config={"batch_size": 10}), + Task(link_enrichments_to_feedback), ] await cognee.memify( extraction_tasks=extraction_tasks, @@ -70,9 +72,8 @@ async def run_feedback_enrichment_memify(last_n: int = 5): async def main(): - # await initialize_conversation_and_graph(CONVERSATION) - # is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?") - is_correct = False + await initialize_conversation_and_graph(CONVERSATION) + is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?") if not is_correct: await run_feedback_enrichment_memify(last_n=5)