diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py index d79aca428..3b2af36c2 100644 --- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py @@ -81,3 +81,54 @@ async def test_get_context_collection_not_found_error(mock_vector_engine): ): with pytest.raises(NoDataError, match="No data found"): await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_completion_without_session(mock_vector_engine): + """Test get_completion without session caching.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Alice knows Bob"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), patch( + "cognee.modules.retrieval.triplet_retriever.CacheConfig" + ) as mock_cache_config: + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_vector_engine): + """Test get_completion with provided context.""" + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), patch( + "cognee.modules.retrieval.triplet_retriever.CacheConfig" + ) as mock_cache_config: + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer"