From ab6a1d1b5b9003b7c7111a9b851b054d1b59facc Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 10 Dec 2025 17:57:40 +0100 Subject: [PATCH] feat: adds additional unit tests for temporal retriever --- .../retrieval/temporal_retriever_test.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index c3c6a47f6..4f2171e02 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -1,5 +1,6 @@ from types import SimpleNamespace import pytest +from unittest.mock import AsyncMock, patch, MagicMock from cognee.modules.retrieval.temporal_retriever import TemporalRetriever @@ -222,3 +223,109 @@ async def test_fake_get_context_with_time_filters_and_vector_ranking(): assert ctx.startswith("E2") assert "#####################" in ctx assert "E1" in ctx and "E3" not in ctx + + +@pytest.fixture +def mock_graph_engine(): + """Create a mock graph engine.""" + engine = AsyncMock() + engine.collect_time_ids = AsyncMock() + engine.collect_events = AsyncMock() + return engine + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.embedding_engine = AsyncMock() + engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine): + """Test get_context when time range is extracted from query.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + {"id": "e2", "description": "Event 2"}, + ] + } + ] + + mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05) + mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10) + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + with patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("What happened in 2024?") + + assert isinstance(context, str) + assert len(context) > 0 + assert "Event" in context + + +@pytest.mark.asyncio +async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine): + """Test get_context falls back to triplets when no time is extracted.""" + retriever = TemporalRetriever() + + with patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve: + async def mock_extract_time(query): + return None, None + + retriever.extract_time_from_query = mock_extract_time + + context = await retriever.get_context("test query") + + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_context_no_events_found(mock_graph_engine): + """Test get_context falls back to triplets when no events are found.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = [] + + with patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve: + async def mock_extract_time(query): + return "2024-01-01", "2024-12-31" + + retriever.extract_time_from_query = mock_extract_time + + context = await retriever.get_context("test query") + + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once()