From 9d900f48cd5f416bcfe0eaaf6a0122961b83c878 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 10 Dec 2025 18:28:11 +0100 Subject: [PATCH] feat: adds unit test for cot retriever --- .../graph_completion_retriever_cot_test.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py new file mode 100644 index 000000000..d3df83516 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -0,0 +1,51 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +@pytest.mark.asyncio +async def test_get_triplets_inherited(mock_edge): + """Test that get_triplets is inherited from parent class.""" + retriever = GraphCompletionCotRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ): + triplets = await retriever.get_triplets("test query") + + assert len(triplets) == 1 + assert triplets[0] == mock_edge + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionCotRetriever initialization with defaults.""" + retriever = GraphCompletionCotRetriever() + + assert retriever.top_k == 5 + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionCotRetriever initialization with custom parameters.""" + retriever = GraphCompletionCotRetriever( + top_k=10, + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + ) + + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt"