From 9d9a388804b1cee16db241739bd6e4224eaf2c3f Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Wed, 10 Dec 2025 18:31:03 +0100
Subject: [PATCH] feat: adds unit test for context extension retriever and chunks retriever

---
 .../retrieval/chunks_retriever_test.py        | 120 ++++++++++++++++++
 ...letion_retriever_context_extension_test.py |  53 ++++++++
 2 files changed, 173 insertions(+)
 create mode 100644 cognee/tests/unit/modules/retrieval/chunks_retriever_test.py
 create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py

diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py
new file mode 100644
index 000000000..4b9c00e08
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py
@@ -0,0 +1,120 @@
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock
+
+from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
+from cognee.modules.retrieval.exceptions.exceptions import NoDataError
+from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
+
+
+@pytest.fixture
+def mock_vector_engine():
+    """Provide an AsyncMock vector engine whose `search` can be stubbed per test."""
+    engine = AsyncMock()
+    engine.search = AsyncMock()
+    return engine
+
+
+@pytest.mark.asyncio
+async def test_get_context_success(mock_vector_engine):
+    """Test that get_context returns result payloads and passes top_k as the search limit."""
+    mock_result1 = MagicMock()
+    mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0}
+    mock_result2 = MagicMock()
+    mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1}
+
+    mock_vector_engine.search.return_value = [mock_result1, mock_result2]
+
+    retriever = ChunksRetriever(top_k=5)
+
+    with patch(
+        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        context = await retriever.get_context("test query")
+
+    assert len(context) == 2
+    assert context[0]["text"] == "Steve Rodger"
+    assert context[1]["text"] == "Mike Broski"
+    mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
+
+
+@pytest.mark.asyncio
+async def test_get_context_collection_not_found_error(mock_vector_engine):
+    """Test that a CollectionNotFoundError from search surfaces as NoDataError."""
+    mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")
+
+    retriever = ChunksRetriever()
+
+    with patch(
+        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        with pytest.raises(NoDataError, match="No data found"):
+            await retriever.get_context("test query")
+
+
+@pytest.mark.asyncio
+async def test_get_context_empty_results(mock_vector_engine):
+    """Test that get_context returns an empty list when search yields no chunks."""
+    mock_vector_engine.search.return_value = []
+
+    retriever = ChunksRetriever()
+
+    with patch(
+        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        context = await retriever.get_context("test query")
+
+    assert context == []
+
+
+@pytest.mark.asyncio
+async def test_get_context_top_k_limit(mock_vector_engine):
+    """Test that top_k is passed to the vector search as its limit argument."""
+    mock_results = [MagicMock() for _ in range(3)]
+    for i, result in enumerate(mock_results):
+        result.payload = {"text": f"Chunk {i}"}
+
+    mock_vector_engine.search.return_value = mock_results
+
+    retriever = ChunksRetriever(top_k=3)
+
+    with patch(
+        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        context = await retriever.get_context("test query")
+
+    assert len(context) == 3
+    mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_context(mock_vector_engine):
+    """Test that get_completion returns a caller-provided context unchanged."""
+    retriever = ChunksRetriever()
+
+    provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}]
+    completion = await retriever.get_completion("test query", context=provided_context)
+
+    assert completion == provided_context
+
+
+@pytest.mark.asyncio
+async def test_get_completion_without_context(mock_vector_engine):
+    """Test that get_completion fetches context via search when none is provided."""
+    mock_result = MagicMock()
+    mock_result.payload = {"text": "Steve Rodger"}
+    mock_vector_engine.search.return_value = [mock_result]
+
+    retriever = ChunksRetriever()
+
+    with patch(
+        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        completion = await retriever.get_completion("test query")
+
+    assert len(completion) == 1
+    assert completion[0]["text"] == "Steve Rodger"
diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py
new file mode 100644
index 000000000..9d8ca3a79
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py
@@ -0,0 +1,53 @@
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock
+
+from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
+    GraphCompletionContextExtensionRetriever,
+)
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+
+
+@pytest.fixture
+def mock_edge():
+    """Provide a MagicMock constrained to the Edge spec."""
+    edge = MagicMock(spec=Edge)
+    return edge
+
+
+@pytest.mark.asyncio
+async def test_get_triplets_inherited(mock_edge):
+    """Test that get_triplets returns the result of the patched brute_force_triplet_search."""
+    retriever = GraphCompletionContextExtensionRetriever()
+
+    with patch(
+        "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
+        return_value=[mock_edge],
+    ):
+        triplets = await retriever.get_triplets("test query")
+
+    assert len(triplets) == 1
+    assert triplets[0] == mock_edge
+
+
+@pytest.mark.asyncio
+async def test_init_defaults():
+    """Test GraphCompletionContextExtensionRetriever initialization with defaults."""
+    retriever = GraphCompletionContextExtensionRetriever()
+
+    assert retriever.top_k == 5
+    assert retriever.user_prompt_path == "graph_context_for_question.txt"
+    assert retriever.system_prompt_path == "answer_simple_question.txt"
+
+
+@pytest.mark.asyncio
+async def test_init_custom_params():
+    """Test GraphCompletionContextExtensionRetriever initialization with custom parameters."""
+    retriever = GraphCompletionContextExtensionRetriever(
+        top_k=10,
+        user_prompt_path="custom_user.txt",
+        system_prompt_path="custom_system.txt",
+    )
+
+    assert retriever.top_k == 10
+    assert retriever.user_prompt_path == "custom_user.txt"
+    assert retriever.system_prompt_path == "custom_system.txt"