feat: use datapoints only
This commit is contained in:
parent
8e580bd3d3
commit
590c3ad7ec
7 changed files with 184 additions and 160 deletions
13
cognee/tasks/feedback/__init__.py
Normal file
13
cognee/tasks/feedback/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from .extract_feedback_interactions import extract_feedback_interactions
|
||||
from .generate_improved_answers import generate_improved_answers
|
||||
from .create_enrichments import create_enrichments
|
||||
from .link_enrichments_to_feedback import link_enrichments_to_feedback
|
||||
from .models import FeedbackEnrichment
|
||||
|
||||
__all__ = [
|
||||
"extract_feedback_interactions",
|
||||
"generate_improved_answers",
|
||||
"create_enrichments",
|
||||
"link_enrichments_to_feedback",
|
||||
"FeedbackEnrichment",
|
||||
]
|
||||
|
|
@ -14,36 +14,19 @@ from .models import FeedbackEnrichment
|
|||
logger = get_logger("create_enrichments")
|
||||
|
||||
|
||||
def _validate_improved_answers(improved_answers: List[Dict]) -> bool:
|
||||
"""Validate that all items contain required fields for enrichment creation."""
|
||||
required_fields = [
|
||||
"question",
|
||||
"answer", # This is the original answer field from feedback_interaction
|
||||
"improved_answer",
|
||||
"new_context",
|
||||
"feedback_id",
|
||||
"interaction_id",
|
||||
]
|
||||
def _validate_enrichments(enrichments: List[FeedbackEnrichment]) -> bool:
|
||||
"""Validate that all enrichments contain required fields for completion."""
|
||||
return all(
|
||||
all(item.get(field) is not None for field in required_fields) for item in improved_answers
|
||||
enrichment.question is not None
|
||||
and enrichment.original_answer is not None
|
||||
and enrichment.improved_answer is not None
|
||||
and enrichment.new_context is not None
|
||||
and enrichment.feedback_id is not None
|
||||
and enrichment.interaction_id is not None
|
||||
for enrichment in enrichments
|
||||
)
|
||||
|
||||
|
||||
def _validate_uuid_fields(improved_answers: List[Dict]) -> bool:
    """Check that every item carries real ``uuid.UUID`` identifiers.

    Args:
        improved_answers: Dict records; each is looked up under the
            ``"feedback_id"`` and ``"interaction_id"`` keys.

    Returns:
        True when both keys of every record hold a ``uuid.UUID`` instance
        (vacuously True for an empty list). False when any value is missing,
        None, or of another type, or when the input is not an iterable of
        dict-like objects.
    """
    # Local import keeps this block self-contained regardless of which names
    # the module imports from ``uuid`` at file level.
    from uuid import UUID

    try:
        # The previous check compared each value against its own type
        # (``isinstance(x, type(x))``), which is always true and therefore
        # validated nothing. Compare against UUID explicitly.
        return all(
            isinstance(item.get("feedback_id"), UUID)
            and isinstance(item.get("interaction_id"), UUID)
            for item in improved_answers
        )
    except (AttributeError, TypeError):
        # Non-iterable input or items without ``.get`` count as invalid
        # rather than crashing the caller.
        return False
|
||||
|
||||
|
||||
async def _generate_enrichment_report(
|
||||
question: str, improved_answer: str, new_context: str, report_prompt_location: str
|
||||
) -> str:
|
||||
|
|
@ -65,80 +48,37 @@ async def _generate_enrichment_report(
|
|||
return f"Educational content for: {question} - {improved_answer}"
|
||||
|
||||
|
||||
async def _create_enrichment_datapoint(
    improved_answer_item: Dict,
    report_text: str,
    nodeset: NodeSet,
) -> Optional[FeedbackEnrichment]:
    """Build one FeedbackEnrichment from an improved-answer record.

    The enrichment id is a UUIDv5 (OID namespace) derived from the question
    and the improved answer. ``report_text`` becomes the enrichment's text and
    ``nodeset`` its only ``belongs_to_set`` member.

    Returns:
        The new FeedbackEnrichment, or None when construction raised; the
        failure is logged together with the record's question.
    """
    try:
        question_text = improved_answer_item["question"]
        better_answer = improved_answer_item["improved_answer"]
        deterministic_id = str(uuid5(NAMESPACE_OID, f"{question_text}_{better_answer}"))

        return FeedbackEnrichment(
            id=deterministic_id,
            text=report_text,
            question=question_text,
            # The original answer is stored under the "answer" key of the record.
            original_answer=improved_answer_item["answer"],
            improved_answer=better_answer,
            feedback_id=improved_answer_item["feedback_id"],
            interaction_id=improved_answer_item["interaction_id"],
            belongs_to_set=[nodeset],
        )
    except Exception as exc:
        logger.error(
            "Failed to create enrichment datapoint",
            error=str(exc),
            question=improved_answer_item.get("question"),
        )
        return None
|
||||
|
||||
|
||||
async def create_enrichments(
|
||||
improved_answers: List[Dict],
|
||||
enrichments: List[FeedbackEnrichment],
|
||||
report_prompt_location: str = "feedback_report_prompt.txt",
|
||||
) -> List[FeedbackEnrichment]:
|
||||
"""Create FeedbackEnrichment DataPoint instances from improved answers."""
|
||||
if not improved_answers:
|
||||
logger.info("No improved answers provided; returning empty list")
|
||||
"""Fill text and belongs_to_set fields of existing FeedbackEnrichment DataPoints."""
|
||||
if not enrichments:
|
||||
logger.info("No enrichments provided; returning empty list")
|
||||
return []
|
||||
|
||||
if not _validate_improved_answers(improved_answers):
|
||||
if not _validate_enrichments(enrichments):
|
||||
logger.error("Input validation failed; missing required fields")
|
||||
return []
|
||||
|
||||
if not _validate_uuid_fields(improved_answers):
|
||||
logger.error("UUID validation failed; invalid feedback_id or interaction_id")
|
||||
return []
|
||||
logger.info("Completing enrichments", count=len(enrichments))
|
||||
|
||||
logger.info("Creating enrichments", count=len(improved_answers))
|
||||
|
||||
# Create nodeset once for all enrichments
|
||||
nodeset = NodeSet(id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment")
|
||||
|
||||
enrichments: List[FeedbackEnrichment] = []
|
||||
|
||||
for improved_answer_item in improved_answers:
|
||||
question = improved_answer_item["question"]
|
||||
improved_answer = improved_answer_item["improved_answer"]
|
||||
new_context = improved_answer_item["new_context"]
|
||||
completed_enrichments: List[FeedbackEnrichment] = []
|
||||
|
||||
for enrichment in enrichments:
|
||||
report_text = await _generate_enrichment_report(
|
||||
question, improved_answer, new_context, report_prompt_location
|
||||
enrichment.question,
|
||||
enrichment.improved_answer,
|
||||
enrichment.new_context,
|
||||
report_prompt_location,
|
||||
)
|
||||
|
||||
enrichment = await _create_enrichment_datapoint(improved_answer_item, report_text, nodeset)
|
||||
enrichment.text = report_text
|
||||
enrichment.belongs_to_set = [nodeset]
|
||||
|
||||
if enrichment:
|
||||
enrichments.append(enrichment)
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to create enrichment",
|
||||
question=question,
|
||||
interaction_id=improved_answer_item.get("interaction_id"),
|
||||
)
|
||||
completed_enrichments.append(enrichment)
|
||||
|
||||
logger.info("Created enrichments", successful=len(enrichments))
|
||||
return enrichments
|
||||
logger.info("Completed enrichments", successful=len(completed_enrichments))
|
||||
return completed_enrichments
|
||||
|
|
|
|||
|
|
@ -7,8 +7,10 @@ from cognee.infrastructure.llm import LLMGateway
|
|||
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
|
||||
from .utils import filter_negative_feedback
|
||||
from .models import FeedbackEnrichment
|
||||
|
||||
|
||||
logger = get_logger("extract_feedback_interactions")
|
||||
|
|
@ -49,11 +51,8 @@ def _match_feedback_nodes_to_interactions_by_edges(
|
|||
feedback_nodes: List, interaction_nodes: List, graph_edges: List
|
||||
) -> List[Tuple[Tuple, Tuple]]:
|
||||
"""Match feedback to interactions using gives_feedback_to edges."""
|
||||
# Build single lookup maps using normalized Cognee IDs
|
||||
interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes}
|
||||
feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes}
|
||||
|
||||
# Filter to only gives_feedback_to edges
|
||||
feedback_edges = [
|
||||
(source_id, target_id)
|
||||
for source_id, target_id, rel, _ in graph_edges
|
||||
|
|
@ -103,23 +102,22 @@ async def _generate_human_readable_context_summary(
|
|||
return raw_context_text or ""
|
||||
|
||||
|
||||
def _has_required_feedback_fields(record: Dict) -> bool:
    """Return True when every mandatory key of *record* maps to a non-None value.

    A key that is absent is treated the same as a key set to None. Empty
    strings and other falsy non-None values are accepted.
    """
    mandatory_keys = (
        "question",
        "answer",
        "context",
        "feedback_text",
        "feedback_id",
        "interaction_id",
    )
    for key in mandatory_keys:
        if record.get(key) is None:
            return False
    return True
|
||||
def _has_required_feedback_fields(enrichment: FeedbackEnrichment) -> bool:
    """Return True when none of the enrichment's mandatory attributes is None.

    The attributes are read in the order listed below and checking stops at
    the first None. Empty strings count as present.
    """
    mandatory_attributes = (
        "question",
        "original_answer",
        "context",
        "feedback_text",
        "feedback_id",
        "interaction_id",
    )
    for attribute_name in mandatory_attributes:
        if getattr(enrichment, attribute_name) is None:
            return False
    return True
|
||||
|
||||
|
||||
async def _build_feedback_interaction_record(
|
||||
feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict
|
||||
) -> Optional[Dict]:
|
||||
"""Build a single feedback-interaction record with context summary."""
|
||||
) -> Optional[FeedbackEnrichment]:
|
||||
"""Build a single FeedbackEnrichment DataPoint with context summary."""
|
||||
try:
|
||||
question_text = interaction_props.get("question")
|
||||
original_answer_text = interaction_props.get("answer")
|
||||
|
|
@ -130,17 +128,23 @@ async def _build_feedback_interaction_record(
|
|||
question_text or "", raw_context_text
|
||||
)
|
||||
|
||||
feedback_interaction_record = {
|
||||
"question": question_text,
|
||||
"answer": original_answer_text,
|
||||
"context": context_summary_text,
|
||||
"feedback_text": feedback_text,
|
||||
"feedback_id": UUID(str(feedback_node_id)),
|
||||
"interaction_id": UUID(str(interaction_node_id)),
|
||||
}
|
||||
enrichment = FeedbackEnrichment(
|
||||
id=str(uuid5(NAMESPACE_OID, f"{question_text}_{interaction_node_id}")),
|
||||
text="",
|
||||
question=question_text,
|
||||
original_answer=original_answer_text,
|
||||
improved_answer="",
|
||||
feedback_id=UUID(str(feedback_node_id)),
|
||||
interaction_id=UUID(str(interaction_node_id)),
|
||||
belongs_to_set=None,
|
||||
context=context_summary_text,
|
||||
feedback_text=feedback_text,
|
||||
new_context="",
|
||||
explanation="",
|
||||
)
|
||||
|
||||
if _has_required_feedback_fields(feedback_interaction_record):
|
||||
return feedback_interaction_record
|
||||
if _has_required_feedback_fields(enrichment):
|
||||
return enrichment
|
||||
else:
|
||||
logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
|
||||
return None
|
||||
|
|
@ -151,9 +155,9 @@ async def _build_feedback_interaction_record(
|
|||
|
||||
async def _build_feedback_interaction_records(
|
||||
matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]],
|
||||
) -> List[Dict]:
|
||||
"""Build all feedback-interaction records from matched pairs."""
|
||||
feedback_interaction_records: List[Dict] = []
|
||||
) -> List[FeedbackEnrichment]:
|
||||
"""Build all FeedbackEnrichment DataPoints from matched pairs."""
|
||||
feedback_interaction_records: List[FeedbackEnrichment] = []
|
||||
for (feedback_node_id, feedback_props), (
|
||||
interaction_node_id,
|
||||
interaction_props,
|
||||
|
|
@ -168,8 +172,8 @@ async def _build_feedback_interaction_records(
|
|||
|
||||
async def extract_feedback_interactions(
|
||||
subgraphs: List, last_n: Optional[int] = None
|
||||
) -> List[Dict]:
|
||||
"""Extract negative feedback-interaction pairs; fetch internally and use last_n param for limiting."""
|
||||
) -> List[FeedbackEnrichment]:
|
||||
"""Extract negative feedback-interaction pairs and create FeedbackEnrichment DataPoints."""
|
||||
graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data()
|
||||
if not graph_nodes:
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from cognee.modules.graph.utils import resolve_edges_to_text
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .utils import create_retriever
|
||||
from .models import FeedbackEnrichment
|
||||
|
||||
|
||||
class ImprovedAnswerResponse(BaseModel):
|
||||
|
|
@ -21,19 +22,16 @@ class ImprovedAnswerResponse(BaseModel):
|
|||
logger = get_logger("generate_improved_answers")
|
||||
|
||||
|
||||
def _validate_input_data(feedback_interactions: List[Dict]) -> bool:
|
||||
"""Validate that input contains required fields for all items."""
|
||||
required_fields = [
|
||||
"question",
|
||||
"answer",
|
||||
"context",
|
||||
"feedback_text",
|
||||
"feedback_id",
|
||||
"interaction_id",
|
||||
]
|
||||
def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool:
|
||||
"""Validate that input contains required fields for all enrichments."""
|
||||
return all(
|
||||
all(item.get(field) is not None for field in required_fields)
|
||||
for item in feedback_interactions
|
||||
enrichment.question is not None
|
||||
and enrichment.original_answer is not None
|
||||
and enrichment.context is not None
|
||||
and enrichment.feedback_text is not None
|
||||
and enrichment.feedback_id is not None
|
||||
and enrichment.interaction_id is not None
|
||||
for enrichment in enrichments
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -51,17 +49,15 @@ def _render_reaction_prompt(
|
|||
|
||||
|
||||
async def _generate_improved_answer_for_single_interaction(
|
||||
feedback_interaction: Dict, retriever, reaction_prompt_location: str
|
||||
) -> Optional[Dict]:
|
||||
"""Generate improved answer for a single feedback-interaction pair using structured retriever completion."""
|
||||
enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str
|
||||
) -> Optional[FeedbackEnrichment]:
|
||||
"""Generate improved answer for a single enrichment using structured retriever completion."""
|
||||
try:
|
||||
question_text = feedback_interaction["question"]
|
||||
original_answer_text = feedback_interaction["answer"]
|
||||
context_text = feedback_interaction["context"]
|
||||
feedback_text = feedback_interaction["feedback_text"]
|
||||
|
||||
query_text = _render_reaction_prompt(
|
||||
question_text, context_text, original_answer_text, feedback_text
|
||||
enrichment.question,
|
||||
enrichment.context,
|
||||
enrichment.original_answer,
|
||||
enrichment.feedback_text,
|
||||
)
|
||||
|
||||
retrieved_context = await retriever.get_context(query_text)
|
||||
|
|
@ -69,20 +65,18 @@ async def _generate_improved_answer_for_single_interaction(
|
|||
query=query_text,
|
||||
context=retrieved_context,
|
||||
response_model=ImprovedAnswerResponse,
|
||||
max_iter=1,
|
||||
max_iter=4,
|
||||
)
|
||||
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
|
||||
|
||||
if completion:
|
||||
return {
|
||||
**feedback_interaction,
|
||||
"improved_answer": completion.answer,
|
||||
"new_context": new_context_text,
|
||||
"explanation": completion.explanation,
|
||||
}
|
||||
enrichment.improved_answer = completion.answer
|
||||
enrichment.new_context = new_context_text
|
||||
enrichment.explanation = completion.explanation
|
||||
return enrichment
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to get structured completion from retriever", question=question_text
|
||||
"Failed to get structured completion from retriever", question=enrichment.question
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -90,23 +84,23 @@ async def _generate_improved_answer_for_single_interaction(
|
|||
logger.error(
|
||||
"Failed to generate improved answer",
|
||||
error=str(exc),
|
||||
question=feedback_interaction.get("question"),
|
||||
question=enrichment.question,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def generate_improved_answers(
|
||||
feedback_interactions: List[Dict],
|
||||
enrichments: List[FeedbackEnrichment],
|
||||
retriever_name: str = "graph_completion_cot",
|
||||
top_k: int = 20,
|
||||
reaction_prompt_location: str = "feedback_reaction_prompt.txt",
|
||||
) -> List[Dict]:
|
||||
) -> List[FeedbackEnrichment]:
|
||||
"""Generate improved answers using configurable retriever and LLM."""
|
||||
if not feedback_interactions:
|
||||
logger.info("No feedback interactions provided; returning empty list")
|
||||
if not enrichments:
|
||||
logger.info("No enrichments provided; returning empty list")
|
||||
return []
|
||||
|
||||
if not _validate_input_data(feedback_interactions):
|
||||
if not _validate_input_data(enrichments):
|
||||
logger.error("Input data validation failed; missing required fields")
|
||||
return []
|
||||
|
||||
|
|
@ -117,11 +111,11 @@ async def generate_improved_answers(
|
|||
system_prompt_path="answer_simple_question.txt",
|
||||
)
|
||||
|
||||
improved_answers: List[Dict] = []
|
||||
improved_answers: List[FeedbackEnrichment] = []
|
||||
|
||||
for feedback_interaction in feedback_interactions:
|
||||
for enrichment in enrichments:
|
||||
result = await _generate_improved_answer_for_single_interaction(
|
||||
feedback_interaction, retriever, reaction_prompt_location
|
||||
enrichment, retriever, reaction_prompt_location
|
||||
)
|
||||
|
||||
if result:
|
||||
|
|
@ -129,8 +123,8 @@ async def generate_improved_answers(
|
|||
else:
|
||||
logger.warning(
|
||||
"Failed to generate improved answer",
|
||||
question=feedback_interaction.get("question"),
|
||||
interaction_id=feedback_interaction.get("interaction_id"),
|
||||
question=enrichment.question,
|
||||
interaction_id=enrichment.interaction_id,
|
||||
)
|
||||
|
||||
logger.info("Generated improved answers", count=len(improved_answers))
|
||||
|
|
|
|||
67
cognee/tasks/feedback/link_enrichments_to_feedback.py
Normal file
67
cognee/tasks/feedback/link_enrichments_to_feedback.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.tasks.storage import index_graph_edges
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .models import FeedbackEnrichment
|
||||
|
||||
|
||||
logger = get_logger("link_enrichments_to_feedback")
|
||||
|
||||
|
||||
def _create_edge_tuple(
    source_id: UUID, target_id: UUID, relationship_name: str
) -> Tuple[UUID, UUID, str, dict]:
    """Assemble the 4-tuple describing one directed edge.

    Returns:
        ``(source_id, target_id, relationship_name, properties)`` where
        ``properties`` repeats the relationship name and both endpoint ids and
        sets ``ontology_valid`` to False.
    """
    edge_properties = dict(
        relationship_name=relationship_name,
        source_node_id=source_id,
        target_node_id=target_id,
        ontology_valid=False,
    )
    return source_id, target_id, relationship_name, edge_properties
|
||||
|
||||
|
||||
async def link_enrichments_to_feedback(
    enrichments: List[FeedbackEnrichment],
) -> List[FeedbackEnrichment]:
    """Add graph edges from each enrichment to its feedback and interaction nodes.

    For every enrichment with a truthy ``id``, an ``enriches_feedback`` edge is
    built when ``feedback_id`` is truthy and an ``improves_interaction`` edge
    when ``interaction_id`` is truthy. If any edges were built, they are passed
    to the graph engine's ``add_edges`` and then to ``index_graph_edges``.

    Returns:
        The same ``enrichments`` list that was passed in, or a new empty list
        when the input is empty.
    """
    if not enrichments:
        logger.info("No enrichments provided; returning empty list")
        return []

    edges = []

    for item in enrichments:
        origin_id = item.id
        # Feedback edge first, then interaction edge, for each enrichment.
        targets = (
            (item.feedback_id, "enriches_feedback"),
            (item.interaction_id, "improves_interaction"),
        )
        for target_id, relationship_label in targets:
            if origin_id and target_id:
                edges.append(_create_edge_tuple(origin_id, target_id, relationship_label))

    if edges:
        engine = await get_graph_engine()
        await engine.add_edges(edges)
        await index_graph_edges(edges)
        logger.info("Linking enrichments to feedback", edge_count=len(edges))

    logger.info("Linked enrichments", enrichment_count=len(enrichments))
    return enrichments
|
||||
|
|
@ -19,3 +19,8 @@ class FeedbackEnrichment(DataPoint):
|
|||
feedback_id: UUID
|
||||
interaction_id: UUID
|
||||
belongs_to_set: Optional[List[NodeSet]] = None
|
||||
|
||||
context: str = ""
|
||||
feedback_text: str = ""
|
||||
new_context: str = ""
|
||||
explanation: str = ""
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from cognee.shared.data_models import KnowledgeGraph
|
|||
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
|
||||
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
|
||||
from cognee.tasks.feedback.create_enrichments import create_enrichments
|
||||
from cognee.tasks.feedback.link_enrichments_to_feedback import link_enrichments_to_feedback
|
||||
|
||||
|
||||
CONVERSATION = [
|
||||
|
|
@ -60,6 +61,7 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
|
|||
Task(create_enrichments),
|
||||
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}),
|
||||
Task(add_data_points, task_config={"batch_size": 10}),
|
||||
Task(link_enrichments_to_feedback),
|
||||
]
|
||||
await cognee.memify(
|
||||
extraction_tasks=extraction_tasks,
|
||||
|
|
@ -70,9 +72,8 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
|
|||
|
||||
|
||||
async def main():
|
||||
# await initialize_conversation_and_graph(CONVERSATION)
|
||||
# is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
|
||||
is_correct = False
|
||||
await initialize_conversation_and_graph(CONVERSATION)
|
||||
is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
|
||||
if not is_correct:
|
||||
await run_feedback_enrichment_memify(last_n=5)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue