feat: use datapoints only

This commit is contained in:
lxobr 2025-10-21 01:30:08 +02:00
parent 8e580bd3d3
commit 590c3ad7ec
7 changed files with 184 additions and 160 deletions

View file

@ -0,0 +1,13 @@
from .extract_feedback_interactions import extract_feedback_interactions
from .generate_improved_answers import generate_improved_answers
from .create_enrichments import create_enrichments
from .link_enrichments_to_feedback import link_enrichments_to_feedback
from .models import FeedbackEnrichment
__all__ = [
"extract_feedback_interactions",
"generate_improved_answers",
"create_enrichments",
"link_enrichments_to_feedback",
"FeedbackEnrichment",
]

View file

@ -14,36 +14,19 @@ from .models import FeedbackEnrichment
logger = get_logger("create_enrichments") logger = get_logger("create_enrichments")
def _validate_improved_answers(improved_answers: List[Dict]) -> bool: def _validate_enrichments(enrichments: List[FeedbackEnrichment]) -> bool:
"""Validate that all items contain required fields for enrichment creation.""" """Validate that all enrichments contain required fields for completion."""
required_fields = [
"question",
"answer", # This is the original answer field from feedback_interaction
"improved_answer",
"new_context",
"feedback_id",
"interaction_id",
]
return all( return all(
all(item.get(field) is not None for field in required_fields) for item in improved_answers enrichment.question is not None
and enrichment.original_answer is not None
and enrichment.improved_answer is not None
and enrichment.new_context is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
for enrichment in enrichments
) )
def _validate_uuid_fields(improved_answers: List[Dict]) -> bool:
"""Validate that feedback_id and interaction_id are valid UUID objects."""
try:
for item in improved_answers:
feedback_id = item.get("feedback_id")
interaction_id = item.get("interaction_id")
if not isinstance(feedback_id, type(feedback_id)) or not isinstance(
interaction_id, type(interaction_id)
):
return False
return True
except Exception:
return False
async def _generate_enrichment_report( async def _generate_enrichment_report(
question: str, improved_answer: str, new_context: str, report_prompt_location: str question: str, improved_answer: str, new_context: str, report_prompt_location: str
) -> str: ) -> str:
@ -65,80 +48,37 @@ async def _generate_enrichment_report(
return f"Educational content for: {question} - {improved_answer}" return f"Educational content for: {question} - {improved_answer}"
async def _create_enrichment_datapoint(
improved_answer_item: Dict,
report_text: str,
nodeset: NodeSet,
) -> Optional[FeedbackEnrichment]:
"""Create a single FeedbackEnrichment DataPoint with proper ID and nodeset assignment."""
try:
question = improved_answer_item["question"]
improved_answer = improved_answer_item["improved_answer"]
enrichment = FeedbackEnrichment(
id=str(uuid5(NAMESPACE_OID, f"{question}_{improved_answer}")),
text=report_text,
question=question,
original_answer=improved_answer_item["answer"], # Use "answer" field
improved_answer=improved_answer,
feedback_id=improved_answer_item["feedback_id"],
interaction_id=improved_answer_item["interaction_id"],
belongs_to_set=[nodeset],
)
return enrichment
except Exception as exc:
logger.error(
"Failed to create enrichment datapoint",
error=str(exc),
question=improved_answer_item.get("question"),
)
return None
async def create_enrichments( async def create_enrichments(
improved_answers: List[Dict], enrichments: List[FeedbackEnrichment],
report_prompt_location: str = "feedback_report_prompt.txt", report_prompt_location: str = "feedback_report_prompt.txt",
) -> List[FeedbackEnrichment]: ) -> List[FeedbackEnrichment]:
"""Create FeedbackEnrichment DataPoint instances from improved answers.""" """Fill text and belongs_to_set fields of existing FeedbackEnrichment DataPoints."""
if not improved_answers: if not enrichments:
logger.info("No improved answers provided; returning empty list") logger.info("No enrichments provided; returning empty list")
return [] return []
if not _validate_improved_answers(improved_answers): if not _validate_enrichments(enrichments):
logger.error("Input validation failed; missing required fields") logger.error("Input validation failed; missing required fields")
return [] return []
if not _validate_uuid_fields(improved_answers): logger.info("Completing enrichments", count=len(enrichments))
logger.error("UUID validation failed; invalid feedback_id or interaction_id")
return []
logger.info("Creating enrichments", count=len(improved_answers))
# Create nodeset once for all enrichments
nodeset = NodeSet(id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment") nodeset = NodeSet(id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment")
enrichments: List[FeedbackEnrichment] = [] completed_enrichments: List[FeedbackEnrichment] = []
for improved_answer_item in improved_answers:
question = improved_answer_item["question"]
improved_answer = improved_answer_item["improved_answer"]
new_context = improved_answer_item["new_context"]
for enrichment in enrichments:
report_text = await _generate_enrichment_report( report_text = await _generate_enrichment_report(
question, improved_answer, new_context, report_prompt_location enrichment.question,
enrichment.improved_answer,
enrichment.new_context,
report_prompt_location,
) )
enrichment = await _create_enrichment_datapoint(improved_answer_item, report_text, nodeset) enrichment.text = report_text
enrichment.belongs_to_set = [nodeset]
if enrichment: completed_enrichments.append(enrichment)
enrichments.append(enrichment)
else:
logger.warning(
"Failed to create enrichment",
question=question,
interaction_id=improved_answer_item.get("interaction_id"),
)
logger.info("Created enrichments", successful=len(enrichments)) logger.info("Completed enrichments", successful=len(completed_enrichments))
return enrichments return completed_enrichments

View file

@ -7,8 +7,10 @@ from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from uuid import uuid5, NAMESPACE_OID
from .utils import filter_negative_feedback from .utils import filter_negative_feedback
from .models import FeedbackEnrichment
logger = get_logger("extract_feedback_interactions") logger = get_logger("extract_feedback_interactions")
@ -49,11 +51,8 @@ def _match_feedback_nodes_to_interactions_by_edges(
feedback_nodes: List, interaction_nodes: List, graph_edges: List feedback_nodes: List, interaction_nodes: List, graph_edges: List
) -> List[Tuple[Tuple, Tuple]]: ) -> List[Tuple[Tuple, Tuple]]:
"""Match feedback to interactions using gives_feedback_to edges.""" """Match feedback to interactions using gives_feedback_to edges."""
# Build single lookup maps using normalized Cognee IDs
interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes} interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes}
feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes} feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes}
# Filter to only gives_feedback_to edges
feedback_edges = [ feedback_edges = [
(source_id, target_id) (source_id, target_id)
for source_id, target_id, rel, _ in graph_edges for source_id, target_id, rel, _ in graph_edges
@ -103,23 +102,22 @@ async def _generate_human_readable_context_summary(
return raw_context_text or "" return raw_context_text or ""
def _has_required_feedback_fields(record: Dict) -> bool: def _has_required_feedback_fields(enrichment: FeedbackEnrichment) -> bool:
"""Validate required fields exist in the item dict.""" """Validate required fields exist in the FeedbackEnrichment DataPoint."""
required_fields = [ return (
"question", enrichment.question is not None
"answer", and enrichment.original_answer is not None
"context", and enrichment.context is not None
"feedback_text", and enrichment.feedback_text is not None
"feedback_id", and enrichment.feedback_id is not None
"interaction_id", and enrichment.interaction_id is not None
] )
return all(record.get(field_name) is not None for field_name in required_fields)
async def _build_feedback_interaction_record( async def _build_feedback_interaction_record(
feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict
) -> Optional[Dict]: ) -> Optional[FeedbackEnrichment]:
"""Build a single feedback-interaction record with context summary.""" """Build a single FeedbackEnrichment DataPoint with context summary."""
try: try:
question_text = interaction_props.get("question") question_text = interaction_props.get("question")
original_answer_text = interaction_props.get("answer") original_answer_text = interaction_props.get("answer")
@ -130,17 +128,23 @@ async def _build_feedback_interaction_record(
question_text or "", raw_context_text question_text or "", raw_context_text
) )
feedback_interaction_record = { enrichment = FeedbackEnrichment(
"question": question_text, id=str(uuid5(NAMESPACE_OID, f"{question_text}_{interaction_node_id}")),
"answer": original_answer_text, text="",
"context": context_summary_text, question=question_text,
"feedback_text": feedback_text, original_answer=original_answer_text,
"feedback_id": UUID(str(feedback_node_id)), improved_answer="",
"interaction_id": UUID(str(interaction_node_id)), feedback_id=UUID(str(feedback_node_id)),
} interaction_id=UUID(str(interaction_node_id)),
belongs_to_set=None,
context=context_summary_text,
feedback_text=feedback_text,
new_context="",
explanation="",
)
if _has_required_feedback_fields(feedback_interaction_record): if _has_required_feedback_fields(enrichment):
return feedback_interaction_record return enrichment
else: else:
logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id)) logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
return None return None
@ -151,9 +155,9 @@ async def _build_feedback_interaction_record(
async def _build_feedback_interaction_records( async def _build_feedback_interaction_records(
matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]], matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]],
) -> List[Dict]: ) -> List[FeedbackEnrichment]:
"""Build all feedback-interaction records from matched pairs.""" """Build all FeedbackEnrichment DataPoints from matched pairs."""
feedback_interaction_records: List[Dict] = [] feedback_interaction_records: List[FeedbackEnrichment] = []
for (feedback_node_id, feedback_props), ( for (feedback_node_id, feedback_props), (
interaction_node_id, interaction_node_id,
interaction_props, interaction_props,
@ -168,8 +172,8 @@ async def _build_feedback_interaction_records(
async def extract_feedback_interactions( async def extract_feedback_interactions(
subgraphs: List, last_n: Optional[int] = None subgraphs: List, last_n: Optional[int] = None
) -> List[Dict]: ) -> List[FeedbackEnrichment]:
"""Extract negative feedback-interaction pairs; fetch internally and use last_n param for limiting.""" """Extract negative feedback-interaction pairs and create FeedbackEnrichment DataPoints."""
graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data() graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data()
if not graph_nodes: if not graph_nodes:
return [] return []

View file

@ -9,6 +9,7 @@ from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from .utils import create_retriever from .utils import create_retriever
from .models import FeedbackEnrichment
class ImprovedAnswerResponse(BaseModel): class ImprovedAnswerResponse(BaseModel):
@ -21,19 +22,16 @@ class ImprovedAnswerResponse(BaseModel):
logger = get_logger("generate_improved_answers") logger = get_logger("generate_improved_answers")
def _validate_input_data(feedback_interactions: List[Dict]) -> bool: def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool:
"""Validate that input contains required fields for all items.""" """Validate that input contains required fields for all enrichments."""
required_fields = [
"question",
"answer",
"context",
"feedback_text",
"feedback_id",
"interaction_id",
]
return all( return all(
all(item.get(field) is not None for field in required_fields) enrichment.question is not None
for item in feedback_interactions and enrichment.original_answer is not None
and enrichment.context is not None
and enrichment.feedback_text is not None
and enrichment.feedback_id is not None
and enrichment.interaction_id is not None
for enrichment in enrichments
) )
@ -51,17 +49,15 @@ def _render_reaction_prompt(
async def _generate_improved_answer_for_single_interaction( async def _generate_improved_answer_for_single_interaction(
feedback_interaction: Dict, retriever, reaction_prompt_location: str enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str
) -> Optional[Dict]: ) -> Optional[FeedbackEnrichment]:
"""Generate improved answer for a single feedback-interaction pair using structured retriever completion.""" """Generate improved answer for a single enrichment using structured retriever completion."""
try: try:
question_text = feedback_interaction["question"]
original_answer_text = feedback_interaction["answer"]
context_text = feedback_interaction["context"]
feedback_text = feedback_interaction["feedback_text"]
query_text = _render_reaction_prompt( query_text = _render_reaction_prompt(
question_text, context_text, original_answer_text, feedback_text enrichment.question,
enrichment.context,
enrichment.original_answer,
enrichment.feedback_text,
) )
retrieved_context = await retriever.get_context(query_text) retrieved_context = await retriever.get_context(query_text)
@ -69,20 +65,18 @@ async def _generate_improved_answer_for_single_interaction(
query=query_text, query=query_text,
context=retrieved_context, context=retrieved_context,
response_model=ImprovedAnswerResponse, response_model=ImprovedAnswerResponse,
max_iter=1, max_iter=4,
) )
new_context_text = await retriever.resolve_edges_to_text(retrieved_context) new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
if completion: if completion:
return { enrichment.improved_answer = completion.answer
**feedback_interaction, enrichment.new_context = new_context_text
"improved_answer": completion.answer, enrichment.explanation = completion.explanation
"new_context": new_context_text, return enrichment
"explanation": completion.explanation,
}
else: else:
logger.warning( logger.warning(
"Failed to get structured completion from retriever", question=question_text "Failed to get structured completion from retriever", question=enrichment.question
) )
return None return None
@ -90,23 +84,23 @@ async def _generate_improved_answer_for_single_interaction(
logger.error( logger.error(
"Failed to generate improved answer", "Failed to generate improved answer",
error=str(exc), error=str(exc),
question=feedback_interaction.get("question"), question=enrichment.question,
) )
return None return None
async def generate_improved_answers( async def generate_improved_answers(
feedback_interactions: List[Dict], enrichments: List[FeedbackEnrichment],
retriever_name: str = "graph_completion_cot", retriever_name: str = "graph_completion_cot",
top_k: int = 20, top_k: int = 20,
reaction_prompt_location: str = "feedback_reaction_prompt.txt", reaction_prompt_location: str = "feedback_reaction_prompt.txt",
) -> List[Dict]: ) -> List[FeedbackEnrichment]:
"""Generate improved answers using configurable retriever and LLM.""" """Generate improved answers using configurable retriever and LLM."""
if not feedback_interactions: if not enrichments:
logger.info("No feedback interactions provided; returning empty list") logger.info("No enrichments provided; returning empty list")
return [] return []
if not _validate_input_data(feedback_interactions): if not _validate_input_data(enrichments):
logger.error("Input data validation failed; missing required fields") logger.error("Input data validation failed; missing required fields")
return [] return []
@ -117,11 +111,11 @@ async def generate_improved_answers(
system_prompt_path="answer_simple_question.txt", system_prompt_path="answer_simple_question.txt",
) )
improved_answers: List[Dict] = [] improved_answers: List[FeedbackEnrichment] = []
for feedback_interaction in feedback_interactions: for enrichment in enrichments:
result = await _generate_improved_answer_for_single_interaction( result = await _generate_improved_answer_for_single_interaction(
feedback_interaction, retriever, reaction_prompt_location enrichment, retriever, reaction_prompt_location
) )
if result: if result:
@ -129,8 +123,8 @@ async def generate_improved_answers(
else: else:
logger.warning( logger.warning(
"Failed to generate improved answer", "Failed to generate improved answer",
question=feedback_interaction.get("question"), question=enrichment.question,
interaction_id=feedback_interaction.get("interaction_id"), interaction_id=enrichment.interaction_id,
) )
logger.info("Generated improved answers", count=len(improved_answers)) logger.info("Generated improved answers", count=len(improved_answers))

View file

@ -0,0 +1,67 @@
from __future__ import annotations
from typing import List, Tuple
from uuid import UUID
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.tasks.storage import index_graph_edges
from cognee.shared.logging_utils import get_logger
from .models import FeedbackEnrichment
logger = get_logger("link_enrichments_to_feedback")
def _create_edge_tuple(
source_id: UUID, target_id: UUID, relationship_name: str
) -> Tuple[UUID, UUID, str, dict]:
"""Create an edge tuple with proper properties structure."""
return (
source_id,
target_id,
relationship_name,
{
"relationship_name": relationship_name,
"source_node_id": source_id,
"target_node_id": target_id,
"ontology_valid": False,
},
)
async def link_enrichments_to_feedback(
enrichments: List[FeedbackEnrichment],
) -> List[FeedbackEnrichment]:
"""Manually create edges from enrichments to original feedback/interaction nodes."""
if not enrichments:
logger.info("No enrichments provided; returning empty list")
return []
relationships = []
for enrichment in enrichments:
enrichment_id = enrichment.id
feedback_id = enrichment.feedback_id
interaction_id = enrichment.interaction_id
if enrichment_id and feedback_id:
enriches_feedback_edge = _create_edge_tuple(
enrichment_id, feedback_id, "enriches_feedback"
)
relationships.append(enriches_feedback_edge)
if enrichment_id and interaction_id:
improves_interaction_edge = _create_edge_tuple(
enrichment_id, interaction_id, "improves_interaction"
)
relationships.append(improves_interaction_edge)
if relationships:
graph_engine = await get_graph_engine()
await graph_engine.add_edges(relationships)
await index_graph_edges(relationships)
logger.info("Linking enrichments to feedback", edge_count=len(relationships))
logger.info("Linked enrichments", enrichment_count=len(enrichments))
return enrichments

View file

@ -19,3 +19,8 @@ class FeedbackEnrichment(DataPoint):
feedback_id: UUID feedback_id: UUID
interaction_id: UUID interaction_id: UUID
belongs_to_set: Optional[List[NodeSet]] = None belongs_to_set: Optional[List[NodeSet]] = None
context: str = ""
feedback_text: str = ""
new_context: str = ""
explanation: str = ""

View file

@ -10,6 +10,7 @@ from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
from cognee.tasks.feedback.create_enrichments import create_enrichments from cognee.tasks.feedback.create_enrichments import create_enrichments
from cognee.tasks.feedback.link_enrichments_to_feedback import link_enrichments_to_feedback
CONVERSATION = [ CONVERSATION = [
@ -60,6 +61,7 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
Task(create_enrichments), Task(create_enrichments),
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}), Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}),
Task(add_data_points, task_config={"batch_size": 10}), Task(add_data_points, task_config={"batch_size": 10}),
Task(link_enrichments_to_feedback),
] ]
await cognee.memify( await cognee.memify(
extraction_tasks=extraction_tasks, extraction_tasks=extraction_tasks,
@ -70,9 +72,8 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
async def main(): async def main():
# await initialize_conversation_and_graph(CONVERSATION) await initialize_conversation_and_graph(CONVERSATION)
# is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?") is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
is_correct = False
if not is_correct: if not is_correct:
await run_feedback_enrichment_memify(last_n=5) await run_feedback_enrichment_memify(last_n=5)