From 29f94e39b98296bebdc3f840925022a919cd88ef Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Wed, 10 Dec 2025 18:53:00 +0100
Subject: [PATCH] feat: increasing coverage in triplet retriever

---
Review notes (free text in this region is dropped by git am / git apply):
* Adds unit tests for TripletRetriever: get_context with a payload that has
  no "text" key (expects KeyError) and with a single result; constructor
  defaults, custom parameters and top_k=None (expects 5); get_completion
  with session caching, with caching enabled but no session user, and with
  a custom response model.
* The body that used to belong to test_get_completion_without_session now
  sits under test_get_completion_without_context, with a different payload
  and has_collection mocked to True.
* test_get_completion_with_session_no_user_id only checks the shape of the
  returned list; it does not patch save_conversation_history or assert that
  saving is skipped.
* test_init_defaults, test_init_custom_params and test_init_none_top_k are
  declared async and marked asyncio although they await nothing.
* Line breaks, indentation and blank context lines were reconstructed from
  the hunk headers (counts match: 196 insertions, 4 deletions); confirm the
  context lines against the preimage before applying.

 .../retrieval/triplet_retriever_test.py       | 200 +++++++++++++++++-
 1 file changed, 196 insertions(+), 4 deletions(-)

diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py
index 59a142e64..83612c7aa 100644
--- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py
+++ b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py
@@ -84,10 +84,75 @@ async def test_get_context_collection_not_found_error(mock_vector_engine):
 
 
 @pytest.mark.asyncio
-async def test_get_completion_without_session(mock_vector_engine):
-    """Test get_completion without session caching."""
+async def test_get_context_empty_payload_text(mock_vector_engine):
+    """Test get_context handles missing text in payload."""
     mock_result = MagicMock()
-    mock_result.payload = {"text": "Alice knows Bob"}
+    mock_result.payload = {}
+
+    mock_vector_engine.search.return_value = [mock_result]
+
+    retriever = TripletRetriever()
+
+    with patch(
+        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        with pytest.raises(KeyError):
+            await retriever.get_context("test query")
+
+
+@pytest.mark.asyncio
+async def test_get_context_single_triplet(mock_vector_engine):
+    """Test get_context with single triplet result."""
+    mock_result = MagicMock()
+    mock_result.payload = {"text": "Single triplet"}
+
+    mock_vector_engine.search.return_value = [mock_result]
+
+    retriever = TripletRetriever()
+
+    with patch(
+        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        context = await retriever.get_context("test query")
+
+    assert context == "Single triplet"
+
+
+@pytest.mark.asyncio
+async def test_init_defaults():
+    """Test TripletRetriever initialization with defaults."""
+    retriever = TripletRetriever()
+
+    assert retriever.user_prompt_path == "context_for_question.txt"
+    assert retriever.system_prompt_path == "answer_simple_question.txt"
+    assert retriever.top_k == 5  # Default is 5
+    assert retriever.system_prompt is None
+
+
+@pytest.mark.asyncio
+async def test_init_custom_params():
+    """Test TripletRetriever initialization with custom parameters."""
+    retriever = TripletRetriever(
+        user_prompt_path="custom_user.txt",
+        system_prompt_path="custom_system.txt",
+        system_prompt="Custom prompt",
+        top_k=10,
+    )
+
+    assert retriever.user_prompt_path == "custom_user.txt"
+    assert retriever.system_prompt_path == "custom_system.txt"
+    assert retriever.system_prompt == "Custom prompt"
+    assert retriever.top_k == 10
+
+
+@pytest.mark.asyncio
+async def test_get_completion_without_context(mock_vector_engine):
+    """Test get_completion retrieves context when not provided."""
+    mock_result = MagicMock()
+    mock_result.payload = {"text": "Test triplet"}
+    mock_vector_engine.has_collection.return_value = True
     mock_vector_engine.search.return_value = [mock_result]
 
     retriever = TripletRetriever()
@@ -116,7 +181,7 @@ async def test_get_completion_without_session(mock_vector_engine):
 
 @pytest.mark.asyncio
 async def test_get_completion_with_provided_context(mock_vector_engine):
-    """Test get_completion with provided context."""
+    """Test get_completion uses provided context."""
     retriever = TripletRetriever()
 
     with (
@@ -135,3 +200,130 @@ async def test_get_completion_with_provided_context(mock_vector_engine):
     assert isinstance(completion, list)
     assert len(completion) == 1
     assert completion[0] == "Generated answer"
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_session(mock_vector_engine):
+    """Test get_completion with session caching enabled."""
+    mock_result = MagicMock()
+    mock_result.payload = {"text": "Test triplet"}
+    mock_vector_engine.has_collection.return_value = True
+    mock_vector_engine.search.return_value = [mock_result]
+
+    retriever = TripletRetriever()
+
+    mock_user = MagicMock()
+    mock_user.id = "test-user-id"
+
+    with (
+        patch(
+            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.triplet_retriever.get_conversation_history",
+            return_value="Previous conversation",
+        ),
+        patch(
+            "cognee.modules.retrieval.triplet_retriever.summarize_text",
+            return_value="Context summary",
+        ),
+        patch(
+            "cognee.modules.retrieval.triplet_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch(
+            "cognee.modules.retrieval.triplet_retriever.save_conversation_history",
+        ) as mock_save,
+        patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
+        patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = True
+        mock_cache_config.return_value = mock_config
+        mock_session_user.get.return_value = mock_user
+
+        completion = await retriever.get_completion("test query", session_id="test_session")
+
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+    assert completion[0] == "Generated answer"
+    mock_save.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_session_no_user_id(mock_vector_engine):
+    """Test get_completion with session config but no user ID."""
+    mock_result = MagicMock()
+    mock_result.payload = {"text": "Test triplet"}
+    mock_vector_engine.has_collection.return_value = True
+    mock_vector_engine.search.return_value = [mock_result]
+
+    retriever = TripletRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.triplet_retriever.generate_completion",
+            return_value="Generated answer",
+        ),
+        patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
+        patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = True
+        mock_cache_config.return_value = mock_config
+        mock_session_user.get.return_value = None  # No user
+
+        completion = await retriever.get_completion("test query")
+
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+
+
+@pytest.mark.asyncio
+async def test_get_completion_with_response_model(mock_vector_engine):
+    """Test get_completion with custom response model."""
+    from pydantic import BaseModel
+
+    class TestModel(BaseModel):
+        answer: str
+
+    mock_result = MagicMock()
+    mock_result.payload = {"text": "Test triplet"}
+    mock_vector_engine.has_collection.return_value = True
+    mock_vector_engine.search.return_value = [mock_result]
+
+    retriever = TripletRetriever()
+
+    with (
+        patch(
+            "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.triplet_retriever.generate_completion",
+            return_value=TestModel(answer="Test answer"),
+        ),
+        patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
+    ):
+        mock_config = MagicMock()
+        mock_config.caching = False
+        mock_cache_config.return_value = mock_config
+
+        completion = await retriever.get_completion("test query", response_model=TestModel)
+
+    assert isinstance(completion, list)
+    assert len(completion) == 1
+    assert isinstance(completion[0], TestModel)
+
+
+@pytest.mark.asyncio
+async def test_init_none_top_k():
+    """Test TripletRetriever initialization with None top_k."""
+    retriever = TripletRetriever(top_k=None)
+
+    assert retriever.top_k == 5