feat: use datapoints only

This commit is contained in:
lxobr 2025-10-21 01:30:08 +02:00
parent 8e580bd3d3
commit 590c3ad7ec
7 changed files with 184 additions and 160 deletions

View file

@ -0,0 +1,13 @@
from .extract_feedback_interactions import extract_feedback_interactions
from .generate_improved_answers import generate_improved_answers
from .create_enrichments import create_enrichments
from .link_enrichments_to_feedback import link_enrichments_to_feedback
from .models import FeedbackEnrichment
__all__ = [
"extract_feedback_interactions",
"generate_improved_answers",
"create_enrichments",
"link_enrichments_to_feedback",
"FeedbackEnrichment",
]

View file

@ -14,36 +14,19 @@ from .models import FeedbackEnrichment
logger = get_logger("create_enrichments")
def _validate_improved_answers(improved_answers: List[Dict]) -> bool:
"""Validate that all items contain required fields for enrichment creation."""
required_fields = [
"question",
"answer", # This is the original answer field from feedback_interaction
"improved_answer",
"new_context",
"feedback_id",
"interaction_id",
]
def _validate_enrichments(enrichments: List[FeedbackEnrichment]) -> bool:
"""Validate that all enrichments contain required fields for completion."""
return all(
all(item.get(field) is not None for field in required_fields) for item in improved_answers
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.improved_answer is not None
and enrichment.new_context is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
for enrichment in enrichments
)
def _validate_uuid_fields(improved_answers: List[Dict]) -> bool:
"""Validate that feedback_id and interaction_id are valid UUID objects."""
try:
for item in improved_answers:
feedback_id = item.get("feedback_id")
interaction_id = item.get("interaction_id")
if not isinstance(feedback_id, type(feedback_id)) or not isinstance(
interaction_id, type(interaction_id)
):
return False
return True
except Exception:
return False
async def _generate_enrichment_report(
question: str, improved_answer: str, new_context: str, report_prompt_location: str
) -> str:
@ -65,80 +48,37 @@ async def _generate_enrichment_report(
return f"Educational content for: {question} - {improved_answer}"
async def _create_enrichment_datapoint(
improved_answer_item: Dict,
report_text: str,
nodeset: NodeSet,
) -> Optional[FeedbackEnrichment]:
"""Create a single FeedbackEnrichment DataPoint with proper ID and nodeset assignment."""
try:
question = improved_answer_item["question"]
improved_answer = improved_answer_item["improved_answer"]
enrichment = FeedbackEnrichment(
id=str(uuid5(NAMESPACE_OID, f"{question}_{improved_answer}")),
text=report_text,
question=question,
original_answer=improved_answer_item["answer"], # Use "answer" field
improved_answer=improved_answer,
feedback_id=improved_answer_item["feedback_id"],
interaction_id=improved_answer_item["interaction_id"],
belongs_to_set=[nodeset],
)
return enrichment
except Exception as exc:
logger.error(
"Failed to create enrichment datapoint",
error=str(exc),
question=improved_answer_item.get("question"),
)
return None
async def create_enrichments(
improved_answers: List[Dict],
enrichments: List[FeedbackEnrichment],
report_prompt_location: str = "feedback_report_prompt.txt",
) -> List[FeedbackEnrichment]:
"""Create FeedbackEnrichment DataPoint instances from improved answers."""
if not improved_answers:
logger.info("No improved answers provided; returning empty list")
"""Fill text and belongs_to_set fields of existing FeedbackEnrichment DataPoints."""
if not enrichments:
logger.info("No enrichments provided; returning empty list")
return []
if not _validate_improved_answers(improved_answers):
if not _validate_enrichments(enrichments):
logger.error("Input validation failed; missing required fields")
return []
if not _validate_uuid_fields(improved_answers):
logger.error("UUID validation failed; invalid feedback_id or interaction_id")
return []
logger.info("Completing enrichments", count=len(enrichments))
logger.info("Creating enrichments", count=len(improved_answers))
# Create nodeset once for all enrichments
nodeset = NodeSet(id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment")
enrichments: List[FeedbackEnrichment] = []
for improved_answer_item in improved_answers:
question = improved_answer_item["question"]
improved_answer = improved_answer_item["improved_answer"]
new_context = improved_answer_item["new_context"]
completed_enrichments: List[FeedbackEnrichment] = []
for enrichment in enrichments:
report_text = await _generate_enrichment_report(
question, improved_answer, new_context, report_prompt_location
enrichment.question,
enrichment.improved_answer,
enrichment.new_context,
report_prompt_location,
)
enrichment = await _create_enrichment_datapoint(improved_answer_item, report_text, nodeset)
enrichment.text = report_text
enrichment.belongs_to_set = [nodeset]
if enrichment:
enrichments.append(enrichment)
else:
logger.warning(
"Failed to create enrichment",
question=question,
interaction_id=improved_answer_item.get("interaction_id"),
)
completed_enrichments.append(enrichment)
logger.info("Created enrichments", successful=len(enrichments))
return enrichments
logger.info("Completed enrichments", successful=len(completed_enrichments))
return completed_enrichments

View file

@ -7,8 +7,10 @@ from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from uuid import uuid5, NAMESPACE_OID
from .utils import filter_negative_feedback
from .models import FeedbackEnrichment
logger = get_logger("extract_feedback_interactions")
@ -49,11 +51,8 @@ def _match_feedback_nodes_to_interactions_by_edges(
feedback_nodes: List, interaction_nodes: List, graph_edges: List
) -> List[Tuple[Tuple, Tuple]]:
"""Match feedback to interactions using gives_feedback_to edges."""
# Build single lookup maps using normalized Cognee IDs
interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes}
feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes}
# Filter to only gives_feedback_to edges
feedback_edges = [
(source_id, target_id)
for source_id, target_id, rel, _ in graph_edges
@ -103,23 +102,22 @@ async def _generate_human_readable_context_summary(
return raw_context_text or ""
def _has_required_feedback_fields(record: Dict) -> bool:
"""Validate required fields exist in the item dict."""
required_fields = [
"question",
"answer",
"context",
"feedback_text",
"feedback_id",
"interaction_id",
]
return all(record.get(field_name) is not None for field_name in required_fields)
def _has_required_feedback_fields(enrichment: FeedbackEnrichment) -> bool:
"""Validate required fields exist in the FeedbackEnrichment DataPoint."""
return (
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.context is not None
and enrichment.feedback_text is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
)
async def _build_feedback_interaction_record(
feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict
) -> Optional[Dict]:
"""Build a single feedback-interaction record with context summary."""
) -> Optional[FeedbackEnrichment]:
"""Build a single FeedbackEnrichment DataPoint with context summary."""
try:
question_text = interaction_props.get("question")
original_answer_text = interaction_props.get("answer")
@ -130,17 +128,23 @@ async def _build_feedback_interaction_record(
question_text or "", raw_context_text
)
feedback_interaction_record = {
"question": question_text,
"answer": original_answer_text,
"context": context_summary_text,
"feedback_text": feedback_text,
"feedback_id": UUID(str(feedback_node_id)),
"interaction_id": UUID(str(interaction_node_id)),
}
enrichment = FeedbackEnrichment(
id=str(uuid5(NAMESPACE_OID, f"{question_text}_{interaction_node_id}")),
text="",
question=question_text,
original_answer=original_answer_text,
improved_answer="",
feedback_id=UUID(str(feedback_node_id)),
interaction_id=UUID(str(interaction_node_id)),
belongs_to_set=None,
context=context_summary_text,
feedback_text=feedback_text,
new_context="",
explanation="",
)
if _has_required_feedback_fields(feedback_interaction_record):
return feedback_interaction_record
if _has_required_feedback_fields(enrichment):
return enrichment
else:
logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
return None
@ -151,9 +155,9 @@ async def _build_feedback_interaction_record(
async def _build_feedback_interaction_records(
matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]],
) -> List[Dict]:
"""Build all feedback-interaction records from matched pairs."""
feedback_interaction_records: List[Dict] = []
) -> List[FeedbackEnrichment]:
"""Build all FeedbackEnrichment DataPoints from matched pairs."""
feedback_interaction_records: List[FeedbackEnrichment] = []
for (feedback_node_id, feedback_props), (
interaction_node_id,
interaction_props,
@ -168,8 +172,8 @@ async def _build_feedback_interaction_records(
async def extract_feedback_interactions(
subgraphs: List, last_n: Optional[int] = None
) -> List[Dict]:
"""Extract negative feedback-interaction pairs; fetch internally and use last_n param for limiting."""
) -> List[FeedbackEnrichment]:
"""Extract negative feedback-interaction pairs and create FeedbackEnrichment DataPoints."""
graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data()
if not graph_nodes:
return []

View file

@ -9,6 +9,7 @@ from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.shared.logging_utils import get_logger
from .utils import create_retriever
from .models import FeedbackEnrichment
class ImprovedAnswerResponse(BaseModel):
@ -21,19 +22,16 @@ class ImprovedAnswerResponse(BaseModel):
logger = get_logger("generate_improved_answers")
def _validate_input_data(feedback_interactions: List[Dict]) -> bool:
"""Validate that input contains required fields for all items."""
required_fields = [
"question",
"answer",
"context",
"feedback_text",
"feedback_id",
"interaction_id",
]
def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool:
"""Validate that input contains required fields for all enrichments."""
return all(
all(item.get(field) is not None for field in required_fields)
for item in feedback_interactions
enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.context is not None
and enrichment.feedback_text is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
for enrichment in enrichments
)
@ -51,17 +49,15 @@ def _render_reaction_prompt(
async def _generate_improved_answer_for_single_interaction(
feedback_interaction: Dict, retriever, reaction_prompt_location: str
) -> Optional[Dict]:
"""Generate improved answer for a single feedback-interaction pair using structured retriever completion."""
enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str
) -> Optional[FeedbackEnrichment]:
"""Generate improved answer for a single enrichment using structured retriever completion."""
try:
question_text = feedback_interaction["question"]
original_answer_text = feedback_interaction["answer"]
context_text = feedback_interaction["context"]
feedback_text = feedback_interaction["feedback_text"]
query_text = _render_reaction_prompt(
question_text, context_text, original_answer_text, feedback_text
enrichment.question,
enrichment.context,
enrichment.original_answer,
enrichment.feedback_text,
)
retrieved_context = await retriever.get_context(query_text)
@ -69,20 +65,18 @@ async def _generate_improved_answer_for_single_interaction(
query=query_text,
context=retrieved_context,
response_model=ImprovedAnswerResponse,
max_iter=1,
max_iter=4,
)
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
if completion:
return {
**feedback_interaction,
"improved_answer": completion.answer,
"new_context": new_context_text,
"explanation": completion.explanation,
}
enrichment.improved_answer = completion.answer
enrichment.new_context = new_context_text
enrichment.explanation = completion.explanation
return enrichment
else:
logger.warning(
"Failed to get structured completion from retriever", question=question_text
"Failed to get structured completion from retriever", question=enrichment.question
)
return None
@ -90,23 +84,23 @@ async def _generate_improved_answer_for_single_interaction(
logger.error(
"Failed to generate improved answer",
error=str(exc),
question=feedback_interaction.get("question"),
question=enrichment.question,
)
return None
async def generate_improved_answers(
feedback_interactions: List[Dict],
enrichments: List[FeedbackEnrichment],
retriever_name: str = "graph_completion_cot",
top_k: int = 20,
reaction_prompt_location: str = "feedback_reaction_prompt.txt",
) -> List[Dict]:
) -> List[FeedbackEnrichment]:
"""Generate improved answers using configurable retriever and LLM."""
if not feedback_interactions:
logger.info("No feedback interactions provided; returning empty list")
if not enrichments:
logger.info("No enrichments provided; returning empty list")
return []
if not _validate_input_data(feedback_interactions):
if not _validate_input_data(enrichments):
logger.error("Input data validation failed; missing required fields")
return []
@ -117,11 +111,11 @@ async def generate_improved_answers(
system_prompt_path="answer_simple_question.txt",
)
improved_answers: List[Dict] = []
improved_answers: List[FeedbackEnrichment] = []
for feedback_interaction in feedback_interactions:
for enrichment in enrichments:
result = await _generate_improved_answer_for_single_interaction(
feedback_interaction, retriever, reaction_prompt_location
enrichment, retriever, reaction_prompt_location
)
if result:
@ -129,8 +123,8 @@ async def generate_improved_answers(
else:
logger.warning(
"Failed to generate improved answer",
question=feedback_interaction.get("question"),
interaction_id=feedback_interaction.get("interaction_id"),
question=enrichment.question,
interaction_id=enrichment.interaction_id,
)
logger.info("Generated improved answers", count=len(improved_answers))

View file

@ -0,0 +1,67 @@
from __future__ import annotations
from typing import List, Tuple
from uuid import UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.tasks.storage import index_graph_edges
from cognee.shared.logging_utils import get_logger
from .models import FeedbackEnrichment
logger = get_logger("link_enrichments_to_feedback")
def _create_edge_tuple(
source_id: UUID, target_id: UUID, relationship_name: str
) -> Tuple[UUID, UUID, str, dict]:
"""Create an edge tuple with proper properties structure."""
return (
source_id,
target_id,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id,
"ontology_valid": False,
},
)
async def link_enrichments_to_feedback(
enrichments: List[FeedbackEnrichment],
) -> List[FeedbackEnrichment]:
"""Manually create edges from enrichments to original feedback/interaction nodes."""
if not enrichments:
logger.info("No enrichments provided; returning empty list")
return []
relationships = []
for enrichment in enrichments:
enrichment_id = enrichment.id
feedback_id = enrichment.feedback_id
interaction_id = enrichment.interaction_id
if enrichment_id and feedback_id:
enriches_feedback_edge = _create_edge_tuple(
enrichment_id, feedback_id, "enriches_feedback"
)
relationships.append(enriches_feedback_edge)
if enrichment_id and interaction_id:
improves_interaction_edge = _create_edge_tuple(
enrichment_id, interaction_id, "improves_interaction"
)
relationships.append(improves_interaction_edge)
if relationships:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
await index_graph_edges(relationships)
logger.info("Linking enrichments to feedback", edge_count=len(relationships))
logger.info("Linked enrichments", enrichment_count=len(enrichments))
return enrichments

View file

@ -19,3 +19,8 @@ class FeedbackEnrichment(DataPoint):
feedback_id: UUID
interaction_id: UUID
belongs_to_set: Optional[List[NodeSet]] = None
context: str = ""
feedback_text: str = ""
new_context: str = ""
explanation: str = ""

View file

@ -10,6 +10,7 @@ from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
from cognee.tasks.feedback.create_enrichments import create_enrichments
from cognee.tasks.feedback.link_enrichments_to_feedback import link_enrichments_to_feedback
CONVERSATION = [
@ -60,6 +61,7 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
Task(create_enrichments),
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}),
Task(add_data_points, task_config={"batch_size": 10}),
Task(link_enrichments_to_feedback),
]
await cognee.memify(
extraction_tasks=extraction_tasks,
@ -70,9 +72,8 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
async def main():
# await initialize_conversation_and_graph(CONVERSATION)
# is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
is_correct = False
await initialize_conversation_and_graph(CONVERSATION)
is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
if not is_correct:
await run_feedback_enrichment_memify(last_n=5)