feat: feedback enrichment (#1571)
<!-- .github/pull_request_template.md --> ## Description <!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. --> - Automatically finds negative user feedback and generates better answers - All tasks work with the same FeedbackEnrichment DataPoint that gets filled out as it moves through the memify pipeline - Creates new nodes and edges in the knowledge graph, linking improved answers back to the original feedback and interactions - Includes a complete example showing how to set up a conversation, ask questions, submit feedback, and run the enrichment pipeline when answers are wrong ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
commit
d682f2e2e8
16 changed files with 1128 additions and 85 deletions
28
.github/workflows/e2e_tests.yml
vendored
28
.github/workflows/e2e_tests.yml
vendored
|
|
@ -358,6 +358,34 @@ jobs:
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/tasks/entity_extraction/entity_extraction_test.py
|
run: uv run python ./cognee/tests/tasks/entity_extraction/entity_extraction_test.py
|
||||||
|
|
||||||
|
test-feedback-enrichment:
|
||||||
|
name: Test Feedback Enrichment
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
|
||||||
|
- name: Dependencies already installed
|
||||||
|
run: echo "Dependencies already installed in setup"
|
||||||
|
|
||||||
|
- name: Run Feedback Enrichment Test
|
||||||
|
env:
|
||||||
|
ENV: 'dev'
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
|
run: uv run python ./cognee/tests/test_feedback_enrichment.py
|
||||||
|
|
||||||
run_conversation_sessions_test:
|
run_conversation_sessions_test:
|
||||||
name: Conversation sessions test
|
name: Conversation sessions test
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
|
||||||
|
|
@ -1366,9 +1366,15 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
params[param_name] = values
|
params[param_name] = values
|
||||||
|
|
||||||
where_clause = " AND ".join(where_clauses)
|
where_clause = " AND ".join(where_clauses)
|
||||||
nodes_query = (
|
nodes_query = f"""
|
||||||
f"MATCH (n:Node) WHERE {where_clause} RETURN n.id, {{properties: n.properties}}"
|
MATCH (n:Node)
|
||||||
)
|
WHERE {where_clause}
|
||||||
|
RETURN n.id, {{
|
||||||
|
name: n.name,
|
||||||
|
type: n.type,
|
||||||
|
properties: n.properties
|
||||||
|
}}
|
||||||
|
"""
|
||||||
edges_query = f"""
|
edges_query = f"""
|
||||||
MATCH (n1:Node)-[r:EDGE]->(n2:Node)
|
MATCH (n1:Node)-[r:EDGE]->(n2:Node)
|
||||||
WHERE {where_clause.replace("n.", "n1.")} AND {where_clause.replace("n.", "n2.")}
|
WHERE {where_clause.replace("n.", "n1.")} AND {where_clause.replace("n.", "n2.")}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
A question was previously answered, but the answer received negative feedback.
|
||||||
|
Please reconsider and improve the response.
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
Context originally used: {context}
|
||||||
|
Previous answer: {wrong_answer}
|
||||||
|
Feedback on that answer: {negative_feedback}
|
||||||
|
|
||||||
|
Task: Provide a better response. The new answer should be short and direct.
|
||||||
|
Then explain briefly why this answer is better.
|
||||||
|
|
||||||
|
Format your reply as:
|
||||||
|
Answer: <improved answer>
|
||||||
|
Explanation: <short explanation>
|
||||||
13
cognee/infrastructure/llm/prompts/feedback_report_prompt.txt
Normal file
13
cognee/infrastructure/llm/prompts/feedback_report_prompt.txt
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
Write a concise, stand-alone paragraph that explains the correct answer to the question below.
|
||||||
|
The paragraph should read naturally on its own, providing all necessary context and reasoning
|
||||||
|
so the answer is clear and well-supported.
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
Correct answer: {improved_answer}
|
||||||
|
Supporting context: {new_context}
|
||||||
|
|
||||||
|
Your paragraph should:
|
||||||
|
- First sentence clearly states the correct answer as a full sentence
|
||||||
|
- Remainder flows from first sentence and provides explanation based on context
|
||||||
|
- Use simple, direct language that is easy to follow
|
||||||
|
- Use shorter sentences, no long-winded explanations
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
Question: {question}
|
||||||
|
Context: {context}
|
||||||
|
|
||||||
|
Provide a one paragraph human readable summary of this interaction context,
|
||||||
|
listing all the relevant facts and information in a simple and direct way.
|
||||||
|
|
@ -1,10 +1,15 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
from typing import Optional, List, Type, Any
|
from typing import Optional, List, Type, Any
|
||||||
|
from pydantic import BaseModel
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
|
from cognee.modules.retrieval.utils.completion import (
|
||||||
|
generate_structured_completion,
|
||||||
|
summarize_text,
|
||||||
|
)
|
||||||
from cognee.modules.retrieval.utils.session_cache import (
|
from cognee.modules.retrieval.utils.session_cache import (
|
||||||
save_conversation_history,
|
save_conversation_history,
|
||||||
get_conversation_history,
|
get_conversation_history,
|
||||||
|
|
@ -17,6 +22,20 @@ from cognee.infrastructure.databases.cache.config import CacheConfig
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _as_answer_text(completion: Any) -> str:
|
||||||
|
"""Convert completion to human-readable text for validation and follow-up prompts."""
|
||||||
|
if isinstance(completion, str):
|
||||||
|
return completion
|
||||||
|
if isinstance(completion, BaseModel):
|
||||||
|
# Add notice that this is a structured response
|
||||||
|
json_str = completion.model_dump_json(indent=2)
|
||||||
|
return f"[Structured Response]\n{json_str}"
|
||||||
|
try:
|
||||||
|
return json.dumps(completion, indent=2)
|
||||||
|
except TypeError:
|
||||||
|
return str(completion)
|
||||||
|
|
||||||
|
|
||||||
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
"""
|
"""
|
||||||
Handles graph completion by generating responses based on a series of interactions with
|
Handles graph completion by generating responses based on a series of interactions with
|
||||||
|
|
@ -25,6 +44,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
questions based on reasoning. The public methods are:
|
questions based on reasoning. The public methods are:
|
||||||
|
|
||||||
- get_completion
|
- get_completion
|
||||||
|
- get_structured_completion
|
||||||
|
|
||||||
Instance variables include:
|
Instance variables include:
|
||||||
- validation_system_prompt_path
|
- validation_system_prompt_path
|
||||||
|
|
@ -61,6 +81,155 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
self.followup_system_prompt_path = followup_system_prompt_path
|
self.followup_system_prompt_path = followup_system_prompt_path
|
||||||
self.followup_user_prompt_path = followup_user_prompt_path
|
self.followup_user_prompt_path = followup_user_prompt_path
|
||||||
|
|
||||||
|
async def _run_cot_completion(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[List[Edge]] = None,
|
||||||
|
conversation_history: str = "",
|
||||||
|
max_iter: int = 4,
|
||||||
|
response_model: Type = str,
|
||||||
|
) -> tuple[Any, str, List[Edge]]:
|
||||||
|
"""
|
||||||
|
Run chain-of-thought completion with optional structured output.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- query: User query
|
||||||
|
- context: Optional pre-fetched context edges
|
||||||
|
- conversation_history: Optional conversation history string
|
||||||
|
- max_iter: Maximum CoT iterations
|
||||||
|
- response_model: Type for structured output (str for plain text)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- completion_result: The generated completion (string or structured model)
|
||||||
|
- context_text: The resolved context text
|
||||||
|
- triplets: The list of triplets used
|
||||||
|
"""
|
||||||
|
followup_question = ""
|
||||||
|
triplets = []
|
||||||
|
completion = ""
|
||||||
|
|
||||||
|
for round_idx in range(max_iter + 1):
|
||||||
|
if round_idx == 0:
|
||||||
|
if context is None:
|
||||||
|
triplets = await self.get_context(query)
|
||||||
|
context_text = await self.resolve_edges_to_text(triplets)
|
||||||
|
else:
|
||||||
|
context_text = await self.resolve_edges_to_text(context)
|
||||||
|
else:
|
||||||
|
triplets += await self.get_context(followup_question)
|
||||||
|
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
||||||
|
|
||||||
|
completion = await generate_structured_completion(
|
||||||
|
query=query,
|
||||||
|
context=context_text,
|
||||||
|
user_prompt_path=self.user_prompt_path,
|
||||||
|
system_prompt_path=self.system_prompt_path,
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
conversation_history=conversation_history if conversation_history else None,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
||||||
|
|
||||||
|
if round_idx < max_iter:
|
||||||
|
answer_text = _as_answer_text(completion)
|
||||||
|
valid_args = {"query": query, "answer": answer_text, "context": context_text}
|
||||||
|
valid_user_prompt = render_prompt(
|
||||||
|
filename=self.validation_user_prompt_path, context=valid_args
|
||||||
|
)
|
||||||
|
valid_system_prompt = read_query_prompt(
|
||||||
|
prompt_file_name=self.validation_system_prompt_path
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning = await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=valid_user_prompt,
|
||||||
|
system_prompt=valid_system_prompt,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
followup_args = {"query": query, "answer": answer_text, "reasoning": reasoning}
|
||||||
|
followup_prompt = render_prompt(
|
||||||
|
filename=self.followup_user_prompt_path, context=followup_args
|
||||||
|
)
|
||||||
|
followup_system = read_query_prompt(
|
||||||
|
prompt_file_name=self.followup_system_prompt_path
|
||||||
|
)
|
||||||
|
|
||||||
|
followup_question = await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=followup_prompt, system_prompt=followup_system, response_model=str
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return completion, context_text, triplets
|
||||||
|
|
||||||
|
async def get_structured_completion(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
context: Optional[List[Edge]] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
max_iter: int = 4,
|
||||||
|
response_model: Type = str,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Generate structured completion responses based on a user query and contextual information.
|
||||||
|
|
||||||
|
This method applies the same chain-of-thought logic as get_completion but returns
|
||||||
|
structured output using the provided response model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- query (str): The user's query to be processed and answered.
|
||||||
|
- context (Optional[List[Edge]]): Optional context that may assist in answering the query.
|
||||||
|
If not provided, it will be fetched based on the query. (default None)
|
||||||
|
- session_id (Optional[str]): Optional session identifier for caching. If None,
|
||||||
|
defaults to 'default_session'. (default None)
|
||||||
|
- max_iter: The maximum number of iterations to refine the answer and generate
|
||||||
|
follow-up questions. (default 4)
|
||||||
|
- response_model (Type): The Pydantic model type for structured output. (default str)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- Any: The generated structured completion based on the response model.
|
||||||
|
"""
|
||||||
|
# Check if session saving is enabled
|
||||||
|
cache_config = CacheConfig()
|
||||||
|
user = session_user.get()
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
session_save = user_id and cache_config.caching
|
||||||
|
|
||||||
|
# Load conversation history if enabled
|
||||||
|
conversation_history = ""
|
||||||
|
if session_save:
|
||||||
|
conversation_history = await get_conversation_history(session_id=session_id)
|
||||||
|
|
||||||
|
completion, context_text, triplets = await self._run_cot_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
conversation_history=conversation_history,
|
||||||
|
max_iter=max_iter,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.save_interaction and context and triplets and completion:
|
||||||
|
await self.save_qa(
|
||||||
|
question=query, answer=str(completion), context=context_text, triplets=triplets
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to session cache if enabled
|
||||||
|
if session_save:
|
||||||
|
context_summary = await summarize_text(context_text)
|
||||||
|
await save_conversation_history(
|
||||||
|
query=query,
|
||||||
|
context_summary=context_summary,
|
||||||
|
answer=str(completion),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return completion
|
||||||
|
|
||||||
async def get_completion(
|
async def get_completion(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -92,82 +261,12 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
||||||
|
|
||||||
- List[str]: A list containing the generated answer to the user's query.
|
- List[str]: A list containing the generated answer to the user's query.
|
||||||
"""
|
"""
|
||||||
followup_question = ""
|
completion = await self.get_structured_completion(
|
||||||
triplets = []
|
query=query,
|
||||||
completion = ""
|
context=context,
|
||||||
|
session_id=session_id,
|
||||||
# Retrieve conversation history if session saving is enabled
|
max_iter=max_iter,
|
||||||
cache_config = CacheConfig()
|
response_model=str,
|
||||||
user = session_user.get()
|
)
|
||||||
user_id = getattr(user, "id", None)
|
|
||||||
session_save = user_id and cache_config.caching
|
|
||||||
|
|
||||||
conversation_history = ""
|
|
||||||
if session_save:
|
|
||||||
conversation_history = await get_conversation_history(session_id=session_id)
|
|
||||||
|
|
||||||
for round_idx in range(max_iter + 1):
|
|
||||||
if round_idx == 0:
|
|
||||||
if context is None:
|
|
||||||
triplets = await self.get_context(query)
|
|
||||||
context_text = await self.resolve_edges_to_text(triplets)
|
|
||||||
else:
|
|
||||||
context_text = await self.resolve_edges_to_text(context)
|
|
||||||
else:
|
|
||||||
triplets += await self.get_context(followup_question)
|
|
||||||
context_text = await self.resolve_edges_to_text(list(set(triplets)))
|
|
||||||
|
|
||||||
completion = await generate_completion(
|
|
||||||
query=query,
|
|
||||||
context=context_text,
|
|
||||||
user_prompt_path=self.user_prompt_path,
|
|
||||||
system_prompt_path=self.system_prompt_path,
|
|
||||||
system_prompt=self.system_prompt,
|
|
||||||
conversation_history=conversation_history if session_save else None,
|
|
||||||
)
|
|
||||||
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
|
|
||||||
if round_idx < max_iter:
|
|
||||||
valid_args = {"query": query, "answer": completion, "context": context_text}
|
|
||||||
valid_user_prompt = render_prompt(
|
|
||||||
filename=self.validation_user_prompt_path, context=valid_args
|
|
||||||
)
|
|
||||||
valid_system_prompt = read_query_prompt(
|
|
||||||
prompt_file_name=self.validation_system_prompt_path
|
|
||||||
)
|
|
||||||
|
|
||||||
reasoning = await LLMGateway.acreate_structured_output(
|
|
||||||
text_input=valid_user_prompt,
|
|
||||||
system_prompt=valid_system_prompt,
|
|
||||||
response_model=str,
|
|
||||||
)
|
|
||||||
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
|
|
||||||
followup_prompt = render_prompt(
|
|
||||||
filename=self.followup_user_prompt_path, context=followup_args
|
|
||||||
)
|
|
||||||
followup_system = read_query_prompt(
|
|
||||||
prompt_file_name=self.followup_system_prompt_path
|
|
||||||
)
|
|
||||||
|
|
||||||
followup_question = await LLMGateway.acreate_structured_output(
|
|
||||||
text_input=followup_prompt, system_prompt=followup_system, response_model=str
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.save_interaction and context and triplets and completion:
|
|
||||||
await self.save_qa(
|
|
||||||
question=query, answer=completion, context=context_text, triplets=triplets
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save to session cache
|
|
||||||
if session_save:
|
|
||||||
context_summary = await summarize_text(context_text)
|
|
||||||
await save_conversation_history(
|
|
||||||
query=query,
|
|
||||||
context_summary=context_summary,
|
|
||||||
answer=completion,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return [completion]
|
return [completion]
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,18 @@
|
||||||
from typing import Optional
|
from typing import Optional, Type, Any
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
|
||||||
|
|
||||||
|
|
||||||
async def generate_completion(
|
async def generate_structured_completion(
|
||||||
query: str,
|
query: str,
|
||||||
context: str,
|
context: str,
|
||||||
user_prompt_path: str,
|
user_prompt_path: str,
|
||||||
system_prompt_path: str,
|
system_prompt_path: str,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
conversation_history: Optional[str] = None,
|
conversation_history: Optional[str] = None,
|
||||||
) -> str:
|
response_model: Type = str,
|
||||||
"""Generates a completion using LLM with given context and prompts."""
|
) -> Any:
|
||||||
|
"""Generates a structured completion using LLM with given context and prompts."""
|
||||||
args = {"question": query, "context": context}
|
args = {"question": query, "context": context}
|
||||||
user_prompt = render_prompt(user_prompt_path, args)
|
user_prompt = render_prompt(user_prompt_path, args)
|
||||||
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
|
system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
|
||||||
|
|
@ -23,6 +24,26 @@ async def generate_completion(
|
||||||
return await LLMGateway.acreate_structured_output(
|
return await LLMGateway.acreate_structured_output(
|
||||||
text_input=user_prompt,
|
text_input=user_prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_completion(
|
||||||
|
query: str,
|
||||||
|
context: str,
|
||||||
|
user_prompt_path: str,
|
||||||
|
system_prompt_path: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
conversation_history: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Generates a completion using LLM with given context and prompts."""
|
||||||
|
return await generate_structured_completion(
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
user_prompt_path=user_prompt_path,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
conversation_history=conversation_history,
|
||||||
response_model=str,
|
response_model=str,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
13
cognee/tasks/feedback/__init__.py
Normal file
13
cognee/tasks/feedback/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
from .extract_feedback_interactions import extract_feedback_interactions
|
||||||
|
from .generate_improved_answers import generate_improved_answers
|
||||||
|
from .create_enrichments import create_enrichments
|
||||||
|
from .link_enrichments_to_feedback import link_enrichments_to_feedback
|
||||||
|
from .models import FeedbackEnrichment
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"extract_feedback_interactions",
|
||||||
|
"generate_improved_answers",
|
||||||
|
"create_enrichments",
|
||||||
|
"link_enrichments_to_feedback",
|
||||||
|
"FeedbackEnrichment",
|
||||||
|
]
|
||||||
84
cognee/tasks/feedback/create_enrichments.py
Normal file
84
cognee/tasks/feedback/create_enrichments.py
Normal file
|
|
@ -0,0 +1,84 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.engine.models import NodeSet
|
||||||
|
|
||||||
|
from .models import FeedbackEnrichment
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("create_enrichments")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_enrichments(enrichments: List[FeedbackEnrichment]) -> bool:
|
||||||
|
"""Validate that all enrichments contain required fields for completion."""
|
||||||
|
return all(
|
||||||
|
enrichment.question is not None
|
||||||
|
and enrichment.original_answer is not None
|
||||||
|
and enrichment.improved_answer is not None
|
||||||
|
and enrichment.new_context is not None
|
||||||
|
and enrichment.feedback_id is not None
|
||||||
|
and enrichment.interaction_id is not None
|
||||||
|
for enrichment in enrichments
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_enrichment_report(
|
||||||
|
question: str, improved_answer: str, new_context: str, report_prompt_location: str
|
||||||
|
) -> str:
|
||||||
|
"""Generate educational report using feedback report prompt."""
|
||||||
|
try:
|
||||||
|
prompt_template = read_query_prompt(report_prompt_location)
|
||||||
|
rendered_prompt = prompt_template.format(
|
||||||
|
question=question,
|
||||||
|
improved_answer=improved_answer,
|
||||||
|
new_context=new_context,
|
||||||
|
)
|
||||||
|
return await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=rendered_prompt,
|
||||||
|
system_prompt="You are a helpful assistant that creates educational content.",
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to generate enrichment report", error=str(exc), question=question)
|
||||||
|
return f"Educational content for: {question} - {improved_answer}"
|
||||||
|
|
||||||
|
|
||||||
|
async def create_enrichments(
|
||||||
|
enrichments: List[FeedbackEnrichment],
|
||||||
|
report_prompt_location: str = "feedback_report_prompt.txt",
|
||||||
|
) -> List[FeedbackEnrichment]:
|
||||||
|
"""Fill text and belongs_to_set fields of existing FeedbackEnrichment DataPoints."""
|
||||||
|
if not enrichments:
|
||||||
|
logger.info("No enrichments provided; returning empty list")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not _validate_enrichments(enrichments):
|
||||||
|
logger.error("Input validation failed; missing required fields")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info("Completing enrichments", count=len(enrichments))
|
||||||
|
|
||||||
|
nodeset = NodeSet(id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment")
|
||||||
|
|
||||||
|
completed_enrichments: List[FeedbackEnrichment] = []
|
||||||
|
|
||||||
|
for enrichment in enrichments:
|
||||||
|
report_text = await _generate_enrichment_report(
|
||||||
|
enrichment.question,
|
||||||
|
enrichment.improved_answer,
|
||||||
|
enrichment.new_context,
|
||||||
|
report_prompt_location,
|
||||||
|
)
|
||||||
|
|
||||||
|
enrichment.text = report_text
|
||||||
|
enrichment.belongs_to_set = [nodeset]
|
||||||
|
|
||||||
|
completed_enrichments.append(enrichment)
|
||||||
|
|
||||||
|
logger.info("Completed enrichments", successful=len(completed_enrichments))
|
||||||
|
return completed_enrichments
|
||||||
230
cognee/tasks/feedback/extract_feedback_interactions.py
Normal file
230
cognee/tasks/feedback/extract_feedback_interactions.py
Normal file
|
|
@ -0,0 +1,230 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
|
||||||
|
from .models import FeedbackEnrichment
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("extract_feedback_interactions")
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_negative_feedback(feedback_nodes):
|
||||||
|
"""Filter for negative sentiment feedback using precise sentiment classification."""
|
||||||
|
return [
|
||||||
|
(node_id, props)
|
||||||
|
for node_id, props in feedback_nodes
|
||||||
|
if (props.get("sentiment", "").casefold() == "negative" or props.get("score", 0) < 0)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_normalized_id(node_id, props) -> str:
|
||||||
|
"""Return Cognee node id preference: props.id → props.node_id → raw node_id."""
|
||||||
|
return str(props.get("id") or props.get("node_id") or node_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_feedback_and_interaction_graph_data() -> Tuple[List, List]:
|
||||||
|
"""Fetch feedback and interaction nodes with edges from graph engine."""
|
||||||
|
try:
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
attribute_filters = [{"type": ["CogneeUserFeedback", "CogneeUserInteraction"]}]
|
||||||
|
return await graph_engine.get_filtered_graph_data(attribute_filters)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.error("Failed to fetch filtered graph data", error=str(exc))
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
|
||||||
|
def _separate_feedback_and_interaction_nodes(graph_nodes: List) -> Tuple[List, List]:
|
||||||
|
"""Split nodes into feedback and interaction groups by type field."""
|
||||||
|
feedback_nodes = [
|
||||||
|
(_get_normalized_id(node_id, props), props)
|
||||||
|
for node_id, props in graph_nodes
|
||||||
|
if props.get("type") == "CogneeUserFeedback"
|
||||||
|
]
|
||||||
|
interaction_nodes = [
|
||||||
|
(_get_normalized_id(node_id, props), props)
|
||||||
|
for node_id, props in graph_nodes
|
||||||
|
if props.get("type") == "CogneeUserInteraction"
|
||||||
|
]
|
||||||
|
return feedback_nodes, interaction_nodes
|
||||||
|
|
||||||
|
|
||||||
|
def _match_feedback_nodes_to_interactions_by_edges(
|
||||||
|
feedback_nodes: List, interaction_nodes: List, graph_edges: List
|
||||||
|
) -> List[Tuple[Tuple, Tuple]]:
|
||||||
|
"""Match feedback to interactions using gives_feedback_to edges."""
|
||||||
|
interaction_by_id = {node_id: (node_id, props) for node_id, props in interaction_nodes}
|
||||||
|
feedback_by_id = {node_id: (node_id, props) for node_id, props in feedback_nodes}
|
||||||
|
feedback_edges = [
|
||||||
|
(source_id, target_id)
|
||||||
|
for source_id, target_id, rel, _ in graph_edges
|
||||||
|
if rel == "gives_feedback_to"
|
||||||
|
]
|
||||||
|
|
||||||
|
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]] = []
|
||||||
|
for source_id, target_id in feedback_edges:
|
||||||
|
source_id_str, target_id_str = str(source_id), str(target_id)
|
||||||
|
|
||||||
|
feedback_node = feedback_by_id.get(source_id_str)
|
||||||
|
interaction_node = interaction_by_id.get(target_id_str)
|
||||||
|
|
||||||
|
if feedback_node and interaction_node:
|
||||||
|
feedback_interaction_pairs.append((feedback_node, interaction_node))
|
||||||
|
|
||||||
|
return feedback_interaction_pairs
|
||||||
|
|
||||||
|
|
||||||
|
def _sort_pairs_by_recency_and_limit(
|
||||||
|
feedback_interaction_pairs: List[Tuple[Tuple, Tuple]], last_n_limit: Optional[int]
|
||||||
|
) -> List[Tuple[Tuple, Tuple]]:
|
||||||
|
"""Sort by interaction created_at desc with updated_at fallback, then limit."""
|
||||||
|
|
||||||
|
def _recency_key(pair):
|
||||||
|
_, (_, interaction_props) = pair
|
||||||
|
created_at = interaction_props.get("created_at") or ""
|
||||||
|
updated_at = interaction_props.get("updated_at") or ""
|
||||||
|
return (created_at, updated_at)
|
||||||
|
|
||||||
|
sorted_pairs = sorted(feedback_interaction_pairs, key=_recency_key, reverse=True)
|
||||||
|
return sorted_pairs[: last_n_limit or len(sorted_pairs)]
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_human_readable_context_summary(
|
||||||
|
question_text: str, raw_context_text: str
|
||||||
|
) -> str:
|
||||||
|
"""Generate a concise human-readable summary for given context."""
|
||||||
|
try:
|
||||||
|
prompt = read_query_prompt("feedback_user_context_prompt.txt")
|
||||||
|
rendered = prompt.format(question=question_text, context=raw_context_text)
|
||||||
|
return await LLMGateway.acreate_structured_output(
|
||||||
|
text_input=rendered, system_prompt="", response_model=str
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("Failed to summarize context", error=str(exc))
|
||||||
|
return raw_context_text or ""
|
||||||
|
|
||||||
|
|
||||||
|
def _has_required_feedback_fields(enrichment: FeedbackEnrichment) -> bool:
|
||||||
|
"""Validate required fields exist in the FeedbackEnrichment DataPoint."""
|
||||||
|
return (
|
||||||
|
enrichment.question is not None
|
||||||
|
and enrichment.original_answer is not None
|
||||||
|
and enrichment.context is not None
|
||||||
|
and enrichment.feedback_text is not None
|
||||||
|
and enrichment.feedback_id is not None
|
||||||
|
and enrichment.interaction_id is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_feedback_interaction_record(
|
||||||
|
feedback_node_id: str, feedback_props: Dict, interaction_node_id: str, interaction_props: Dict
|
||||||
|
) -> Optional[FeedbackEnrichment]:
|
||||||
|
"""Build a single FeedbackEnrichment DataPoint with context summary."""
|
||||||
|
try:
|
||||||
|
question_text = interaction_props.get("question")
|
||||||
|
original_answer_text = interaction_props.get("answer")
|
||||||
|
raw_context_text = interaction_props.get("context", "")
|
||||||
|
feedback_text = feedback_props.get("feedback") or feedback_props.get("text") or ""
|
||||||
|
|
||||||
|
context_summary_text = await _generate_human_readable_context_summary(
|
||||||
|
question_text or "", raw_context_text
|
||||||
|
)
|
||||||
|
|
||||||
|
enrichment = FeedbackEnrichment(
|
||||||
|
id=str(uuid5(NAMESPACE_OID, f"{question_text}_{interaction_node_id}")),
|
||||||
|
text="",
|
||||||
|
question=question_text,
|
||||||
|
original_answer=original_answer_text,
|
||||||
|
improved_answer="",
|
||||||
|
feedback_id=UUID(str(feedback_node_id)),
|
||||||
|
interaction_id=UUID(str(interaction_node_id)),
|
||||||
|
belongs_to_set=None,
|
||||||
|
context=context_summary_text,
|
||||||
|
feedback_text=feedback_text,
|
||||||
|
new_context="",
|
||||||
|
explanation="",
|
||||||
|
)
|
||||||
|
|
||||||
|
if _has_required_feedback_fields(enrichment):
|
||||||
|
return enrichment
|
||||||
|
else:
|
||||||
|
logger.warning("Skipping invalid feedback item", interaction=str(interaction_node_id))
|
||||||
|
return None
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.error("Failed to process feedback pair", error=str(exc))
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_feedback_interaction_records(
|
||||||
|
matched_feedback_interaction_pairs: List[Tuple[Tuple, Tuple]],
|
||||||
|
) -> List[FeedbackEnrichment]:
|
||||||
|
"""Build all FeedbackEnrichment DataPoints from matched pairs."""
|
||||||
|
feedback_interaction_records: List[FeedbackEnrichment] = []
|
||||||
|
for (feedback_node_id, feedback_props), (
|
||||||
|
interaction_node_id,
|
||||||
|
interaction_props,
|
||||||
|
) in matched_feedback_interaction_pairs:
|
||||||
|
record = await _build_feedback_interaction_record(
|
||||||
|
feedback_node_id, feedback_props, interaction_node_id, interaction_props
|
||||||
|
)
|
||||||
|
if record:
|
||||||
|
feedback_interaction_records.append(record)
|
||||||
|
return feedback_interaction_records
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_feedback_interactions(
|
||||||
|
data: Any, last_n: Optional[int] = None
|
||||||
|
) -> List[FeedbackEnrichment]:
|
||||||
|
"""Extract negative feedback-interaction pairs and create FeedbackEnrichment DataPoints."""
|
||||||
|
if not data or data == [{}]:
|
||||||
|
logger.info(
|
||||||
|
"No data passed to the extraction task (extraction task fetches data from graph directly)",
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_nodes, graph_edges = await _fetch_feedback_and_interaction_graph_data()
|
||||||
|
if not graph_nodes:
|
||||||
|
logger.warning("No graph nodes retrieved from database")
|
||||||
|
return []
|
||||||
|
|
||||||
|
feedback_nodes, interaction_nodes = _separate_feedback_and_interaction_nodes(graph_nodes)
|
||||||
|
logger.info(
|
||||||
|
"Retrieved nodes from graph",
|
||||||
|
total_nodes=len(graph_nodes),
|
||||||
|
feedback_nodes=len(feedback_nodes),
|
||||||
|
interaction_nodes=len(interaction_nodes),
|
||||||
|
)
|
||||||
|
|
||||||
|
negative_feedback_nodes = _filter_negative_feedback(feedback_nodes)
|
||||||
|
logger.info(
|
||||||
|
"Filtered feedback nodes",
|
||||||
|
total_feedback=len(feedback_nodes),
|
||||||
|
negative_feedback=len(negative_feedback_nodes),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not negative_feedback_nodes:
|
||||||
|
logger.info("No negative feedback found; returning empty list")
|
||||||
|
return []
|
||||||
|
|
||||||
|
matched_feedback_interaction_pairs = _match_feedback_nodes_to_interactions_by_edges(
|
||||||
|
negative_feedback_nodes, interaction_nodes, graph_edges
|
||||||
|
)
|
||||||
|
if not matched_feedback_interaction_pairs:
|
||||||
|
logger.info("No feedback-to-interaction matches found; returning empty list")
|
||||||
|
return []
|
||||||
|
|
||||||
|
matched_feedback_interaction_pairs = _sort_pairs_by_recency_and_limit(
|
||||||
|
matched_feedback_interaction_pairs, last_n
|
||||||
|
)
|
||||||
|
|
||||||
|
feedback_interaction_records = await _build_feedback_interaction_records(
|
||||||
|
matched_feedback_interaction_pairs
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Extracted feedback pairs", count=len(feedback_interaction_records))
|
||||||
|
return feedback_interaction_records
|
||||||
130
cognee/tasks/feedback/generate_improved_answers.py
Normal file
130
cognee/tasks/feedback/generate_improved_answers.py
Normal file
|
|
@ -0,0 +1,130 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
|
||||||
|
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
|
from .models import FeedbackEnrichment
|
||||||
|
|
||||||
|
|
||||||
|
class ImprovedAnswerResponse(BaseModel):
|
||||||
|
"""Response model for improved answer generation containing answer and explanation."""
|
||||||
|
|
||||||
|
answer: str
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("generate_improved_answers")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_input_data(enrichments: List[FeedbackEnrichment]) -> bool:
|
||||||
|
"""Validate that input contains required fields for all enrichments."""
|
||||||
|
return all(
|
||||||
|
enrichment.question is not None
|
||||||
|
and enrichment.original_answer is not None
|
||||||
|
and enrichment.context is not None
|
||||||
|
and enrichment.feedback_text is not None
|
||||||
|
and enrichment.feedback_id is not None
|
||||||
|
and enrichment.interaction_id is not None
|
||||||
|
for enrichment in enrichments
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_reaction_prompt(
|
||||||
|
question: str, context: str, wrong_answer: str, negative_feedback: str
|
||||||
|
) -> str:
|
||||||
|
"""Render the feedback reaction prompt with provided variables."""
|
||||||
|
prompt_template = read_query_prompt("feedback_reaction_prompt.txt")
|
||||||
|
return prompt_template.format(
|
||||||
|
question=question,
|
||||||
|
context=context,
|
||||||
|
wrong_answer=wrong_answer,
|
||||||
|
negative_feedback=negative_feedback,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_improved_answer_for_single_interaction(
|
||||||
|
enrichment: FeedbackEnrichment, retriever, reaction_prompt_location: str
|
||||||
|
) -> Optional[FeedbackEnrichment]:
|
||||||
|
"""Generate improved answer for a single enrichment using structured retriever completion."""
|
||||||
|
try:
|
||||||
|
query_text = _render_reaction_prompt(
|
||||||
|
enrichment.question,
|
||||||
|
enrichment.context,
|
||||||
|
enrichment.original_answer,
|
||||||
|
enrichment.feedback_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
retrieved_context = await retriever.get_context(query_text)
|
||||||
|
completion = await retriever.get_structured_completion(
|
||||||
|
query=query_text,
|
||||||
|
context=retrieved_context,
|
||||||
|
response_model=ImprovedAnswerResponse,
|
||||||
|
max_iter=4,
|
||||||
|
)
|
||||||
|
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)
|
||||||
|
|
||||||
|
if completion:
|
||||||
|
enrichment.improved_answer = completion.answer
|
||||||
|
enrichment.new_context = new_context_text
|
||||||
|
enrichment.explanation = completion.explanation
|
||||||
|
return enrichment
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to get structured completion from retriever", question=enrichment.question
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.error(
|
||||||
|
"Failed to generate improved answer",
|
||||||
|
error=str(exc),
|
||||||
|
question=enrichment.question,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_improved_answers(
|
||||||
|
enrichments: List[FeedbackEnrichment],
|
||||||
|
top_k: int = 20,
|
||||||
|
reaction_prompt_location: str = "feedback_reaction_prompt.txt",
|
||||||
|
) -> List[FeedbackEnrichment]:
|
||||||
|
"""Generate improved answers using CoT retriever and LLM."""
|
||||||
|
if not enrichments:
|
||||||
|
logger.info("No enrichments provided; returning empty list")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not _validate_input_data(enrichments):
|
||||||
|
logger.error("Input data validation failed; missing required fields")
|
||||||
|
return []
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever(
|
||||||
|
top_k=top_k,
|
||||||
|
save_interaction=False,
|
||||||
|
user_prompt_path="graph_context_for_question.txt",
|
||||||
|
system_prompt_path="answer_simple_question.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
improved_answers: List[FeedbackEnrichment] = []
|
||||||
|
|
||||||
|
for enrichment in enrichments:
|
||||||
|
result = await _generate_improved_answer_for_single_interaction(
|
||||||
|
enrichment, retriever, reaction_prompt_location
|
||||||
|
)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
improved_answers.append(result)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to generate improved answer",
|
||||||
|
question=enrichment.question,
|
||||||
|
interaction_id=enrichment.interaction_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Generated improved answers", count=len(improved_answers))
|
||||||
|
return improved_answers
|
||||||
67
cognee/tasks/feedback/link_enrichments_to_feedback.py
Normal file
67
cognee/tasks/feedback/link_enrichments_to_feedback.py
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Tuple
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.tasks.storage import index_graph_edges
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
from .models import FeedbackEnrichment
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("link_enrichments_to_feedback")
|
||||||
|
|
||||||
|
|
||||||
|
def _create_edge_tuple(
|
||||||
|
source_id: UUID, target_id: UUID, relationship_name: str
|
||||||
|
) -> Tuple[UUID, UUID, str, dict]:
|
||||||
|
"""Create an edge tuple with proper properties structure."""
|
||||||
|
return (
|
||||||
|
source_id,
|
||||||
|
target_id,
|
||||||
|
relationship_name,
|
||||||
|
{
|
||||||
|
"relationship_name": relationship_name,
|
||||||
|
"source_node_id": source_id,
|
||||||
|
"target_node_id": target_id,
|
||||||
|
"ontology_valid": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def link_enrichments_to_feedback(
|
||||||
|
enrichments: List[FeedbackEnrichment],
|
||||||
|
) -> List[FeedbackEnrichment]:
|
||||||
|
"""Manually create edges from enrichments to original feedback/interaction nodes."""
|
||||||
|
if not enrichments:
|
||||||
|
logger.info("No enrichments provided; returning empty list")
|
||||||
|
return []
|
||||||
|
|
||||||
|
relationships = []
|
||||||
|
|
||||||
|
for enrichment in enrichments:
|
||||||
|
enrichment_id = enrichment.id
|
||||||
|
feedback_id = enrichment.feedback_id
|
||||||
|
interaction_id = enrichment.interaction_id
|
||||||
|
|
||||||
|
if enrichment_id and feedback_id:
|
||||||
|
enriches_feedback_edge = _create_edge_tuple(
|
||||||
|
enrichment_id, feedback_id, "enriches_feedback"
|
||||||
|
)
|
||||||
|
relationships.append(enriches_feedback_edge)
|
||||||
|
|
||||||
|
if enrichment_id and interaction_id:
|
||||||
|
improves_interaction_edge = _create_edge_tuple(
|
||||||
|
enrichment_id, interaction_id, "improves_interaction"
|
||||||
|
)
|
||||||
|
relationships.append(improves_interaction_edge)
|
||||||
|
|
||||||
|
if relationships:
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
await graph_engine.add_edges(relationships)
|
||||||
|
await index_graph_edges(relationships)
|
||||||
|
logger.info("Linking enrichments to feedback", edge_count=len(relationships))
|
||||||
|
|
||||||
|
logger.info("Linked enrichments", enrichment_count=len(enrichments))
|
||||||
|
return enrichments
|
||||||
26
cognee/tasks/feedback/models.py
Normal file
26
cognee/tasks/feedback/models.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.modules.engine.models import Entity, NodeSet
|
||||||
|
from cognee.tasks.temporal_graph.models import Event
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackEnrichment(DataPoint):
|
||||||
|
"""Minimal DataPoint for feedback enrichment that works with extract_graph_from_data."""
|
||||||
|
|
||||||
|
text: str
|
||||||
|
contains: Optional[List[Union[Entity, Event]]] = None
|
||||||
|
metadata: dict = {"index_fields": ["text"]}
|
||||||
|
|
||||||
|
question: str
|
||||||
|
original_answer: str
|
||||||
|
improved_answer: str
|
||||||
|
feedback_id: UUID
|
||||||
|
interaction_id: UUID
|
||||||
|
belongs_to_set: Optional[List[NodeSet]] = None
|
||||||
|
|
||||||
|
context: str = ""
|
||||||
|
feedback_text: str = ""
|
||||||
|
new_context: str = ""
|
||||||
|
explanation: str = ""
|
||||||
174
cognee/tests/test_feedback_enrichment.py
Normal file
174
cognee/tests/test_feedback_enrichment.py
Normal file
|
|
@ -0,0 +1,174 @@
|
||||||
|
"""
|
||||||
|
End-to-end integration test for feedback enrichment feature.
|
||||||
|
|
||||||
|
Tests the complete feedback enrichment pipeline:
|
||||||
|
1. Add data and cognify
|
||||||
|
2. Run search with save_interaction=True to create CogneeUserInteraction nodes
|
||||||
|
3. Submit feedback to create CogneeUserFeedback nodes
|
||||||
|
4. Run memify with feedback enrichment tasks to create FeedbackEnrichment nodes
|
||||||
|
5. Verify all nodes and edges are properly created and linked in the graph
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.modules.pipelines.tasks.task import Task
|
||||||
|
from cognee.modules.search.types import SearchType
|
||||||
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.tasks.feedback.create_enrichments import create_enrichments
|
||||||
|
from cognee.tasks.feedback.extract_feedback_interactions import (
|
||||||
|
extract_feedback_interactions,
|
||||||
|
)
|
||||||
|
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
|
||||||
|
from cognee.tasks.feedback.link_enrichments_to_feedback import (
|
||||||
|
link_enrichments_to_feedback,
|
||||||
|
)
|
||||||
|
from cognee.tasks.graph import extract_graph_from_data
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
data_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent,
|
||||||
|
".data_storage/test_feedback_enrichment",
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
cognee_directory_path = str(
|
||||||
|
pathlib.Path(
|
||||||
|
os.path.join(
|
||||||
|
pathlib.Path(__file__).parent,
|
||||||
|
".cognee_system/test_feedback_enrichment",
|
||||||
|
)
|
||||||
|
).resolve()
|
||||||
|
)
|
||||||
|
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
cognee.config.system_root_directory(cognee_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
dataset_name = "feedback_enrichment_test"
|
||||||
|
|
||||||
|
await cognee.add("Cognee turns documents into AI memory.", dataset_name)
|
||||||
|
await cognee.cognify([dataset_name])
|
||||||
|
|
||||||
|
question_text = "Say something."
|
||||||
|
result = await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
query_text=question_text,
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) > 0, "Search should return non-empty results"
|
||||||
|
|
||||||
|
feedback_text = "This answer was completely useless, my feedback is definitely negative."
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK,
|
||||||
|
query_text=feedback_text,
|
||||||
|
last_k=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
nodes_before, edges_before = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
interaction_nodes_before = [
|
||||||
|
(node_id, props)
|
||||||
|
for node_id, props in nodes_before
|
||||||
|
if props.get("type") == "CogneeUserInteraction"
|
||||||
|
]
|
||||||
|
feedback_nodes_before = [
|
||||||
|
(node_id, props)
|
||||||
|
for node_id, props in nodes_before
|
||||||
|
if props.get("type") == "CogneeUserFeedback"
|
||||||
|
]
|
||||||
|
|
||||||
|
edge_types_before = Counter(edge[2] for edge in edges_before)
|
||||||
|
|
||||||
|
assert len(interaction_nodes_before) >= 1, (
|
||||||
|
f"Expected at least 1 CogneeUserInteraction node, found {len(interaction_nodes_before)}"
|
||||||
|
)
|
||||||
|
assert len(feedback_nodes_before) >= 1, (
|
||||||
|
f"Expected at least 1 CogneeUserFeedback node, found {len(feedback_nodes_before)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for node_id, props in feedback_nodes_before:
|
||||||
|
sentiment = props.get("sentiment", "")
|
||||||
|
score = props.get("score", 0)
|
||||||
|
feedback_text = props.get("feedback", "")
|
||||||
|
logger.info(
|
||||||
|
"Feedback node created",
|
||||||
|
feedback=feedback_text,
|
||||||
|
sentiment=sentiment,
|
||||||
|
score=score,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert edge_types_before.get("gives_feedback_to", 0) >= 1, (
|
||||||
|
f"Expected at least 1 'gives_feedback_to' edge, found {edge_types_before.get('gives_feedback_to', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
extraction_tasks = [Task(extract_feedback_interactions, last_n=5)]
|
||||||
|
enrichment_tasks = [
|
||||||
|
Task(generate_improved_answers, top_k=20),
|
||||||
|
Task(create_enrichments),
|
||||||
|
Task(
|
||||||
|
extract_graph_from_data,
|
||||||
|
graph_model=KnowledgeGraph,
|
||||||
|
task_config={"batch_size": 10},
|
||||||
|
),
|
||||||
|
Task(add_data_points, task_config={"batch_size": 10}),
|
||||||
|
Task(link_enrichments_to_feedback),
|
||||||
|
]
|
||||||
|
|
||||||
|
await cognee.memify(
|
||||||
|
extraction_tasks=extraction_tasks,
|
||||||
|
enrichment_tasks=enrichment_tasks,
|
||||||
|
data=[{}],
|
||||||
|
dataset="feedback_enrichment_test_memify",
|
||||||
|
)
|
||||||
|
|
||||||
|
nodes_after, edges_after = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
enrichment_nodes = [
|
||||||
|
(node_id, props)
|
||||||
|
for node_id, props in nodes_after
|
||||||
|
if props.get("type") == "FeedbackEnrichment"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(enrichment_nodes) >= 1, (
|
||||||
|
f"Expected at least 1 FeedbackEnrichment node, found {len(enrichment_nodes)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for node_id, props in enrichment_nodes:
|
||||||
|
assert "text" in props, f"FeedbackEnrichment node {node_id} missing 'text' property"
|
||||||
|
|
||||||
|
enrichment_node_ids = {node_id for node_id, _ in enrichment_nodes}
|
||||||
|
edges_with_enrichments = [
|
||||||
|
edge
|
||||||
|
for edge in edges_after
|
||||||
|
if edge[0] in enrichment_node_ids or edge[1] in enrichment_node_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(edges_with_enrichments) >= 1, (
|
||||||
|
f"Expected enrichment nodes to have at least 1 edge, found {len(edges_with_enrichments)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
|
||||||
|
logger.info("All feedback enrichment tests passed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import pytest
|
import pytest
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.low_level import setup, DataPoint
|
from cognee.low_level import setup, DataPoint
|
||||||
|
|
@ -10,6 +11,11 @@ from cognee.tasks.storage import add_data_points
|
||||||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnswer(BaseModel):
|
||||||
|
answer: str
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
|
||||||
class TestGraphCompletionCoTRetriever:
|
class TestGraphCompletionCoTRetriever:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graph_completion_cot_context_simple(self):
|
async def test_graph_completion_cot_context_simple(self):
|
||||||
|
|
@ -168,3 +174,48 @@ class TestGraphCompletionCoTRetriever:
|
||||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||||
"Answer must contain only non-empty strings"
|
"Answer must contain only non-empty strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_structured_completion(self):
|
||||||
|
system_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.system_root_directory(system_directory_path)
|
||||||
|
data_directory_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion"
|
||||||
|
)
|
||||||
|
cognee.config.data_root_directory(data_directory_path)
|
||||||
|
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await setup()
|
||||||
|
|
||||||
|
class Company(DataPoint):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Person(DataPoint):
|
||||||
|
name: str
|
||||||
|
works_for: Company
|
||||||
|
|
||||||
|
company1 = Company(name="Figma")
|
||||||
|
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||||
|
|
||||||
|
entities = [company1, person1]
|
||||||
|
await add_data_points(entities)
|
||||||
|
|
||||||
|
retriever = GraphCompletionCotRetriever()
|
||||||
|
|
||||||
|
# Test with string response model (default)
|
||||||
|
string_answer = await retriever.get_structured_completion("Who works at Figma?")
|
||||||
|
assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
|
||||||
|
assert string_answer.strip(), "Answer should not be empty"
|
||||||
|
|
||||||
|
# Test with structured response model
|
||||||
|
structured_answer = await retriever.get_structured_completion(
|
||||||
|
"Who works at Figma?", response_model=TestAnswer
|
||||||
|
)
|
||||||
|
assert isinstance(structured_answer, TestAnswer), (
|
||||||
|
f"Expected TestAnswer, got {type(structured_answer).__name__}"
|
||||||
|
)
|
||||||
|
assert structured_answer.answer.strip(), "Answer field should not be empty"
|
||||||
|
assert structured_answer.explanation.strip(), "Explanation field should not be empty"
|
||||||
|
|
|
||||||
82
examples/python/feedback_enrichment_minimal_example.py
Normal file
82
examples/python/feedback_enrichment_minimal_example.py
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.api.v1.search import SearchType
|
||||||
|
from cognee.modules.pipelines.tasks.task import Task
|
||||||
|
from cognee.tasks.graph import extract_graph_from_data
|
||||||
|
from cognee.tasks.storage import add_data_points
|
||||||
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
|
|
||||||
|
from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
|
||||||
|
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
|
||||||
|
from cognee.tasks.feedback.create_enrichments import create_enrichments
|
||||||
|
from cognee.tasks.feedback.link_enrichments_to_feedback import link_enrichments_to_feedback
|
||||||
|
|
||||||
|
|
||||||
|
CONVERSATION = [
|
||||||
|
"Alice: Hey, Bob. Did you talk to Mallory?",
|
||||||
|
"Bob: Yeah, I just saw her before coming here.",
|
||||||
|
"Alice: Then she told you to bring my documents, right?",
|
||||||
|
"Bob: Uh… not exactly. She said you wanted me to bring you donuts. Which sounded kind of odd…",
|
||||||
|
"Alice: Ugh, she’s so annoying. Thanks for the donuts anyway!",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize_conversation_and_graph(conversation):
|
||||||
|
"""Prune data/system, add conversation, cognify."""
|
||||||
|
await cognee.prune.prune_data()
|
||||||
|
await cognee.prune.prune_system(metadata=True)
|
||||||
|
await cognee.add(conversation)
|
||||||
|
await cognee.cognify()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_question_and_submit_feedback(question_text: str) -> bool:
|
||||||
|
"""Ask question, submit feedback based on correctness, and return correctness flag."""
|
||||||
|
result = await cognee.search(
|
||||||
|
query_type=SearchType.GRAPH_COMPLETION,
|
||||||
|
query_text=question_text,
|
||||||
|
save_interaction=True,
|
||||||
|
)
|
||||||
|
answer_text = str(result).lower()
|
||||||
|
mentions_mallory = "mallory" in answer_text
|
||||||
|
feedback_text = (
|
||||||
|
"Great answers, very helpful!"
|
||||||
|
if mentions_mallory
|
||||||
|
else "The answer about Bob and donuts was wrong."
|
||||||
|
)
|
||||||
|
await cognee.search(
|
||||||
|
query_type=SearchType.FEEDBACK,
|
||||||
|
query_text=feedback_text,
|
||||||
|
last_k=1,
|
||||||
|
)
|
||||||
|
return mentions_mallory
|
||||||
|
|
||||||
|
|
||||||
|
async def run_feedback_enrichment_memify(last_n: int = 5):
|
||||||
|
"""Execute memify with extraction, answer improvement, enrichment creation, and graph processing tasks."""
|
||||||
|
# Instantiate tasks with their own kwargs
|
||||||
|
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
|
||||||
|
enrichment_tasks = [
|
||||||
|
Task(generate_improved_answers, top_k=20),
|
||||||
|
Task(create_enrichments),
|
||||||
|
Task(extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 10}),
|
||||||
|
Task(add_data_points, task_config={"batch_size": 10}),
|
||||||
|
Task(link_enrichments_to_feedback),
|
||||||
|
]
|
||||||
|
await cognee.memify(
|
||||||
|
extraction_tasks=extraction_tasks,
|
||||||
|
enrichment_tasks=enrichment_tasks,
|
||||||
|
data=[{}], # A placeholder to prevent fetching the entire graph
|
||||||
|
dataset="feedback_enrichment_minimal",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
await initialize_conversation_and_graph(CONVERSATION)
|
||||||
|
is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
|
||||||
|
if not is_correct:
|
||||||
|
await run_feedback_enrichment_memify(last_n=5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
Loading…
Add table
Reference in a new issue