From ef042d9a29ec0c6d46074cb61afa4e0d5487372f Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 11 Dec 2025 09:49:42 +0100 Subject: [PATCH] feat: increases coverage for rag completion retriever --- .../rag_completion_retriever_test.py | 161 ++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index 3c004db5d..e998d419d 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -158,3 +158,164 @@ async def test_get_completion_with_provided_context(mock_vector_engine): assert isinstance(completion, list) assert len(completion) == 1 assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_vector_engine): + """Test get_completion with session caching enabled.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.completion_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_vector_engine): + """Test get_completion with session config but no user ID.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test CompletionRetriever initialization with defaults.""" + retriever = CompletionRetriever() + + assert retriever.user_prompt_path == "context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 1 + assert retriever.system_prompt is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test CompletionRetriever initialization with custom parameters.""" + retriever = CompletionRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + top_k=10, + ) + + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_context_missing_text_key(mock_vector_engine): + """Test get_context handles missing text key in payload.""" + mock_result = MagicMock() + mock_result.payload = {"other_key": "value"} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(KeyError): + await retriever.get_context("test query")