diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py index 4b9c00e08..98bfd48fe 100644 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -118,3 +118,66 @@ async def test_get_completion_without_context(mock_vector_engine): assert len(completion) == 1 assert completion[0]["text"] == "Steve Rodger" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test ChunksRetriever initialization with defaults.""" + retriever = ChunksRetriever() + + assert retriever.top_k == 5 + + +@pytest.mark.asyncio +async def test_init_custom_top_k(): + """Test ChunksRetriever initialization with custom top_k.""" + retriever = ChunksRetriever(top_k=10) + + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test ChunksRetriever initialization with None top_k.""" + retriever = ChunksRetriever(top_k=None) + + assert retriever.top_k is None + + +@pytest.mark.asyncio +async def test_get_context_empty_payload(mock_vector_engine): + """Test get_context handles empty payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 1 + assert context[0] == {} + + +@pytest.mark.asyncio +async def test_get_completion_with_session_id(mock_vector_engine): + """Test get_completion with session_id parameter.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Steve Rodger"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", session_id="test_session") + + assert len(completion) == 1 + assert completion[0]["text"] == "Steve Rodger"