cognee/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py
2025-12-11 13:23:29 +01:00

312 lines
12 KiB
Python

import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from uuid import UUID, NAMESPACE_OID, uuid5
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment
from cognee.modules.engine.models import NodeSet
@pytest.fixture
def mock_feedback_evaluation():
"""Create a mock feedback evaluation."""
evaluation = MagicMock(spec=UserFeedbackEvaluation)
evaluation.evaluation = MagicMock()
evaluation.evaluation.value = "positive"
evaluation.score = 4.5
return evaluation
@pytest.fixture
def mock_graph_engine():
"""Create a mock graph engine."""
engine = AsyncMock()
engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
engine.add_edges = AsyncMock()
engine.apply_feedback_weight = AsyncMock()
return engine
class TestUserQAFeedback:
@pytest.mark.asyncio
async def test_init_default(self):
"""Test UserQAFeedback initialization with default last_k."""
retriever = UserQAFeedback()
assert retriever.last_k == 1
@pytest.mark.asyncio
async def test_init_custom_last_k(self):
"""Test UserQAFeedback initialization with custom last_k."""
retriever = UserQAFeedback(last_k=5)
assert retriever.last_k == 5
@pytest.mark.asyncio
async def test_add_feedback_success_with_relationships(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback successfully creates feedback with relationships."""
interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001"))
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(
return_value=[interaction_id_1, interaction_id_2]
)
feedback_text = "This answer was helpful"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
) as mock_add_data,
patch(
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
new_callable=AsyncMock,
) as mock_index_edges,
):
retriever = UserQAFeedback(last_k=2)
result = await retriever.add_feedback(feedback_text)
assert result == [feedback_text]
mock_add_data.assert_awaited_once()
mock_graph_engine.add_edges.assert_awaited_once()
mock_index_edges.assert_awaited_once()
mock_graph_engine.apply_feedback_weight.assert_awaited_once()
# Verify add_edges was called with correct relationships
call_args = mock_graph_engine.add_edges.call_args[0][0]
assert len(call_args) == 2
assert call_args[0][0] == uuid5(NAMESPACE_OID, name=feedback_text)
assert call_args[0][1] == UUID(interaction_id_1)
assert call_args[0][2] == "gives_feedback_to"
assert call_args[0][3]["relationship_name"] == "gives_feedback_to"
assert call_args[0][3]["ontology_valid"] is False
# Verify apply_feedback_weight was called with correct node IDs
weight_call_args = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"]
assert len(weight_call_args) == 2
assert interaction_id_1 in weight_call_args
assert interaction_id_2 in weight_call_args
@pytest.mark.asyncio
async def test_add_feedback_success_no_relationships(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback successfully creates feedback without relationships."""
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
feedback_text = "This answer was helpful"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
) as mock_add_data,
patch(
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
new_callable=AsyncMock,
) as mock_index_edges,
):
retriever = UserQAFeedback(last_k=1)
result = await retriever.add_feedback(feedback_text)
assert result == [feedback_text]
mock_add_data.assert_awaited_once()
# Should not call add_edges or index_graph_edges when no relationships
mock_graph_engine.add_edges.assert_not_awaited()
mock_index_edges.assert_not_awaited()
mock_graph_engine.apply_feedback_weight.assert_not_awaited()
@pytest.mark.asyncio
async def test_add_feedback_creates_correct_feedback_node(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback creates CogneeUserFeedback with correct attributes."""
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
feedback_text = "This was a negative experience"
mock_feedback_evaluation.evaluation.value = "negative"
mock_feedback_evaluation.score = -3.0
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
) as mock_add_data,
):
retriever = UserQAFeedback()
await retriever.add_feedback(feedback_text)
# Verify add_data_points was called with correct CogneeUserFeedback
call_args = mock_add_data.call_args[1]["data_points"]
assert len(call_args) == 1
feedback_node = call_args[0]
assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text)
assert feedback_node.feedback == feedback_text
assert feedback_node.sentiment == "negative"
assert feedback_node.score == -3.0
assert isinstance(feedback_node.belongs_to_set, NodeSet)
assert feedback_node.belongs_to_set.name == "UserQAFeedbacks"
@pytest.mark.asyncio
async def test_add_feedback_calls_llm_with_correct_prompt(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback calls LLM with correct sentiment analysis prompt."""
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
feedback_text = "Great answer!"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
) as mock_llm,
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
),
):
retriever = UserQAFeedback()
await retriever.add_feedback(feedback_text)
mock_llm.assert_awaited_once()
call_kwargs = mock_llm.call_args[1]
assert call_kwargs["text_input"] == feedback_text
assert "sentiment analysis assistant" in call_kwargs["system_prompt"]
assert call_kwargs["response_model"] == UserFeedbackEvaluation
@pytest.mark.asyncio
async def test_add_feedback_uses_last_k_parameter(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback uses last_k parameter when getting interaction IDs."""
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
feedback_text = "Test feedback"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
),
):
retriever = UserQAFeedback(last_k=5)
await retriever.add_feedback(feedback_text)
mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5)
@pytest.mark.asyncio
async def test_add_feedback_with_single_interaction(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback with single interaction ID."""
interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
feedback_text = "Test feedback"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
new_callable=AsyncMock,
),
):
retriever = UserQAFeedback()
result = await retriever.add_feedback(feedback_text)
assert result == [feedback_text]
# Should create relationship for the interaction
call_args = mock_graph_engine.add_edges.call_args[0][0]
assert len(call_args) == 1
assert call_args[0][1] == UUID(interaction_id)
@pytest.mark.asyncio
async def test_add_feedback_applies_weight_correctly(
self, mock_feedback_evaluation, mock_graph_engine
):
"""Test add_feedback applies feedback weight with correct score."""
interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
mock_feedback_evaluation.score = 4.5
feedback_text = "Positive feedback"
with (
patch(
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
new_callable=AsyncMock,
return_value=mock_feedback_evaluation,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
new_callable=AsyncMock,
),
patch(
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
new_callable=AsyncMock,
),
):
retriever = UserQAFeedback()
await retriever.add_feedback(feedback_text)
mock_graph_engine.apply_feedback_weight.assert_awaited_once_with(
node_ids=[interaction_id], weight=4.5
)