feat: extract feedback interactions
This commit is contained in:
parent
44ec814256
commit
78fca9feb7
2 changed files with 273 additions and 0 deletions
199
cognee/tasks/feedback/extract_feedback_interactions.py
Normal file
199
cognee/tasks/feedback/extract_feedback_interactions.py
Normal file
|
|
@ -0,0 +1,199 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
|
||||||
|
from .utils import filter_negative_feedback
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("extract_feedback_interactions")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_normalized_id(node_id, props) -> str:
|
||||||
|
"""Return Cognee node id preference: props.id → props.node_id → raw node_id."""
|
||||||
|
return str(props.get("id") or props.get("node_id") or node_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_feedback_and_interaction_graph_data() -> Tuple[List, List]:
    """Load feedback and interaction nodes plus their edges from the graph store.

    Returns:
        A ``(nodes, edges)`` tuple from the graph engine, filtered to the
        ``CogneeUserFeedback`` / ``CogneeUserInteraction`` node types.
        Returns ``([], [])`` on any engine failure instead of raising.
    """
    node_type_filter = [{"type": ["CogneeUserFeedback", "CogneeUserInteraction"]}]
    try:
        engine = await get_graph_engine()
        return await engine.get_filtered_graph_data(node_type_filter)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to fetch filtered graph data", error=str(exc))
        return [], []
|
||||||
|
|
||||||
|
|
||||||
|
def _separate_feedback_and_interaction_nodes(graph_nodes: List) -> Tuple[List, List]:
    """Partition ``(node_id, props)`` tuples into feedback and interaction lists.

    Grouping is driven by ``props['type']``; ids are normalized via
    ``_get_normalized_id``. Nodes of any other type are dropped.
    """
    feedback_nodes: List = []
    interaction_nodes: List = []
    for raw_id, props in graph_nodes:
        node_type = props.get("type")
        if node_type == "CogneeUserFeedback":
            feedback_nodes.append((_get_normalized_id(raw_id, props), props))
        elif node_type == "CogneeUserInteraction":
            interaction_nodes.append((_get_normalized_id(raw_id, props), props))
    return feedback_nodes, interaction_nodes
|
||||||
|
|
||||||
|
|
||||||
|
def _match_feedback_nodes_to_interactions_by_edges(
|
||||||
|
feedback_nodes: List, interaction_nodes: List, graph_edges: List
|
||||||
|
) -> List[Tuple[Tuple, Tuple]]:
|
||||||
|
"""Match feedback to interactions using gives_feedback_to edges."""
|
||||||
|
# Build single lookup maps using normalized Cognee IDs
|
||||||
|
interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes}
|
||||||
|
feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes}
|
||||||
|
|
||||||
|
# Filter to only gives_feedback_to edges
|
||||||
|
feedback_edges = [
|
||||||
|
(source_id, target_id)
|
||||||
|
for source_id, target_id, rel, _ in graph_edges
|
||||||
|
if rel == "gives_feedback_to"
|
||||||
|
]
|
||||||
|
|
||||||
|
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]] = []
|
||||||
|
for source_id, target_id in feedback_edges:
|
||||||
|
source_id_str, target_id_str = str(source_id), str(target_id)
|
||||||
|
|
||||||
|
feedback_node = feedback_by_id.get(source_id_str)
|
||||||
|
interaction_node = interaction_by_id.get(target_id_str)
|
||||||
|
|
||||||
|
if feedback_node and interaction_node:
|
||||||
|
feedback_interaction_pairs.append((feedback_node, interaction_node))
|
||||||
|
|
||||||
|
return feedback_interaction_pairs
|
||||||
|
|
||||||
|
|
||||||
|
def _sort_pairs_by_recency_and_limit(
|
||||||
|
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]], last_n_limit: Optional[int]
|
||||||
|
) -> List[Tuple[Tuple, Tuple]]:
|
||||||
|
"""Sort by interaction created_at desc with updated_at fallback, then limit."""
|
||||||
|
|
||||||
|
def _recency_key(pair):
|
||||||
|
_, (_, interaction_props) = pair
|
||||||
|
created_at = interaction_props.get("created_at") or ""
|
||||||
|
updated_at = interaction_props.get("updated_at") or ""
|
||||||
|
return (created_at, updated_at)
|
||||||
|
|
||||||
|
sorted_pairs = sorted(feedback_interaction_pairs, key=_recency_key, reverse=True)
|
||||||
|
return sorted_pairs[: last_n_limit or len(sorted_pairs)]
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_human_readable_context_summary(
    question_text: str, raw_context_text: str
) -> str:
    """Summarize retrieval context into a short human-readable string via the LLM.

    On any failure (prompt load, template render, LLM call) the raw context is
    returned unchanged — or ``""`` when the raw context itself is falsy.
    """
    try:
        prompt_template = read_query_prompt("feedback_user_context_prompt.txt")
        filled_prompt = prompt_template.format(
            question=question_text, context=raw_context_text
        )
        summary = await LLMGateway.acreate_structured_output(
            text_input=filled_prompt, system_prompt="", response_model=str
        )
        return summary
    except Exception as exc:  # noqa: BLE001
        logger.warning("Failed to summarize context", error=str(exc))
        return raw_context_text or ""
|
||||||
|
|
||||||
|
|
||||||
|
def _has_required_feedback_fields(record: Dict) -> bool:
|
||||||
|
"""Validate required fields exist in the item dict."""
|
||||||
|
required_fields = [
|
||||||
|
"question",
|
||||||
|
"answer",
|
||||||
|
"context",
|
||||||
|
"feedback_text",
|
||||||
|
"feedback_id",
|
||||||
|
"interaction_id",
|
||||||
|
]
|
||||||
|
return all(record.get(field_name) is not None for field_name in required_fields)
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_feedback_interaction_record(
    feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict
) -> Optional[Dict]:
    """Assemble one record from a matched feedback/interaction node pair.

    The interaction's raw context is replaced with an LLM-generated summary.
    Node ids are parsed into ``UUID`` objects, so non-UUID ids cause the pair
    to be skipped via the broad exception handler.

    Returns:
        The record dict, or ``None`` when required fields are missing or any
        step (summarization aside — it has its own fallback) fails.
    """
    try:
        question = interaction_props.get("question")
        answer = interaction_props.get("answer")
        raw_context = interaction_props.get("context", "")
        # Feedback text may live under either key depending on the producer.
        feedback_text = feedback_props.get("feedback") or feedback_props.get("text") or ""

        summarized_context = await _generate_human_readable_context_summary(
            question or "", raw_context
        )

        record = {
            "question": question,
            "answer": answer,
            "context": summarized_context,
            "feedback_text": feedback_text,
            "feedback_id": UUID(str(feedback_node_id)),
            "interaction_id": UUID(str(interaction_node_id)),
        }

        if not _has_required_feedback_fields(record):
            logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
            return None
        return record
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to process feedback pair", error=str(exc))
        return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_feedback_interaction_records(
    matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]],
) -> List[Dict]:
    """Materialize records for every matched pair, silently dropping failures.

    Pairs for which ``_build_feedback_interaction_record`` returns a falsy
    value (it logs its own reason) are omitted from the result.
    """
    records: List[Dict] = []
    for feedback_entry, interaction_entry in matched_feedback_interaction_pairs:
        feedback_id, feedback_props = feedback_entry
        interaction_id, interaction_props = interaction_entry
        built = await _build_feedback_interaction_record(
            feedback_id, feedback_props, interaction_id, interaction_props
        )
        if built:
            records.append(built)
    return records
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_feedback_interactions(
    subgraphs: List, last_n: Optional[int] = None
) -> List[Dict]:
    """Extract negative feedback-interaction pairs from the graph.

    Data is fetched internally from the graph engine rather than from the
    ``subgraphs`` argument, which is unused here — presumably required by the
    memify/Task pipeline signature (confirm with callers).

    Args:
        subgraphs: Accepted for pipeline compatibility; not read.
        last_n: Optional cap — keep only the N most recent interactions.

    Returns:
        A list of record dicts (question, answer, summarized context,
        feedback text, and UUID identifiers); empty when nothing matches.
    """
    graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data()
    if not graph_nodes:
        return []

    feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)

    negative_feedback_nodes = filter_negative_feedback(feedback_nodes)
    if not negative_feedback_nodes:
        logger.info("No negative feedback found; returning empty list")
        return []

    matched_pairs = _match_feedback_nodes_to_interactions_by_edges(
        negative_feedback_nodes, interaction_nodes, graph_edges
    )
    if not matched_pairs:
        logger.info("No feedback-to-interaction matches found; returning empty list")
        return []

    limited_pairs = _sort_pairs_by_recency_and_limit(matched_pairs, last_n)
    records = await _build_feedback_interaction_records(limited_pairs)

    logger.info("Extracted feedback pairs", count=len(records))
    return records
|
||||||
74
examples/python/feedback_enrichment_minimal_example.py
Normal file
74
examples/python/feedback_enrichment_minimal_example.py
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.api.v1.search import SearchType
|
||||||
|
from cognee.modules.pipelines.tasks.task import Task
|
||||||
|
|
||||||
|
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
|
||||||
|
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
|
||||||
|
|
||||||
|
|
||||||
|
# Sample multi-turn conversation used to seed the knowledge graph. Mallory's
# misdirection (documents vs. donuts) gives the QA step something to answer
# incorrectly, which is what triggers the negative-feedback flow below.
CONVERSATION = [
    "Alice: Hey, Bob. Did you talk to Mallory?",
    "Bob: Yeah, I just saw her before coming here.",
    "Alice: Then she told you to bring my documents, right?",
    "Bob: Uh… not exactly. She said you wanted me to bring you donuts. Which sounded kind of odd…",
    "Alice: Ugh, she’s so annoying. Thanks for the donuts anyway!",
]
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize_conversation_and_graph(conversation):
    """Prune data/system, add conversation, cognify.

    The call order is significant: both prune steps must complete before
    ingesting, and ``cognify`` must run after ``add``.
    """
    # Wipe stored data, then system metadata, so the example starts clean.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    # Ingest the raw conversation lines and build the knowledge graph.
    await cognee.add(conversation)
    await cognee.cognify()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_question_and_submit_feedback(question_text: str) -> bool:
    """Ask a question, then submit positive or negative feedback on the answer.

    The answer is judged correct when it mentions "mallory" (case-insensitive).
    The interaction is saved so the feedback search can reference it.

    Returns:
        True when the answer was judged correct, False otherwise.
    """
    search_result = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text=question_text,
        save_interaction=True,
    )
    answer_is_correct = "mallory" in str(search_result).lower()

    if answer_is_correct:
        feedback_text = "Great answers, very helpful!"
    else:
        feedback_text = "The answer about Bob and donuts was wrong."

    # last_k=2 attaches the feedback to the most recent saved interactions.
    await cognee.search(
        query_type=SearchType.FEEDBACK,
        query_text=feedback_text,
        last_k=2,
    )
    return answer_is_correct
|
||||||
|
|
||||||
|
|
||||||
|
async def run_feedback_enrichment_memify(last_n: int = 5):
    """Run memify with a feedback-extraction task and an answer-improvement task.

    Args:
        last_n: Forwarded to ``extract_feedback_interactions`` to cap how many
            recent feedback interactions are processed.
    """
    # Tasks carry their own kwargs; memify wires them into a pipeline.
    await cognee.memify(
        extraction_tasks=[Task(extract_feedback_interactions, last_n=last_n)],
        enrichment_tasks=[
            Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20)
        ],
        data=[{}],  # A placeholder to prevent fetching the entire graph
        dataset="feedback_enrichment_minimal",
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Example driver.

    The graph-initialization and question/feedback steps are commented out so
    the enrichment step can be re-run against an already-populated graph;
    uncomment them for a full fresh run.
    """
    # await initialize_conversation_and_graph(CONVERSATION)
    # is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
    is_correct = False  # forced False so the enrichment path below always runs
    if not is_correct:
        await run_feedback_enrichment_memify(last_n=5)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the async example driver.
if __name__ == "__main__":
    asyncio.run(main())
|
||||||
Loading…
Add table
Reference in a new issue