From 21a84d3100d34d1b75c675181489ae16afe00ef1 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 10 Dec 2025 18:58:13 +0100 Subject: [PATCH] feat: increasing coverage of temporal retriever --- .../retrieval/temporal_retriever_test.py | 423 ++++++++++++++---- 1 file changed, 339 insertions(+), 84 deletions(-) diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index fe5908847..c3ba2c864 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -141,90 +141,6 @@ async def test_filter_top_k_events_error_handling(): await tr.filter_top_k_events([{}], []) -class _FakeRetriever(TemporalRetriever): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._calls = [] - - async def extract_time_from_query(self, query: str): - if "both" in query: - return "2024-01-01", "2024-12-31" - if "from_only" in query: - return "2024-01-01", None - if "to_only" in query: - return None, "2024-12-31" - return None, None - - async def get_triplets(self, query: str): - self._calls.append(("get_triplets", query)) - return [{"s": "a", "p": "b", "o": "c"}] - - async def resolve_edges_to_text(self, triplets): - self._calls.append(("resolve_edges_to_text", len(triplets))) - return "edges->text" - - async def _fake_graph_collect_ids(self, **kwargs): - return ["e1", "e2"] - - async def _fake_graph_collect_events(self, ids): - return [ - { - "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3"}, - ] - } - ] - - async def _fake_vector_embed(self, texts): - assert isinstance(texts, list) and texts - return [[0.0, 1.0, 2.0]] - - async def _fake_vector_search(self, **kwargs): - return [ - SimpleNamespace(payload={"id": "e2"}, score=0.05), - SimpleNamespace(payload={"id": "e1"}, score=0.10), - ] - - async def get_context(self, query: str): - time_from, time_to = await self.extract_time_from_query(query) - - if not (time_from or time_to): - triplets = await self.get_triplets(query) - return await self.resolve_edges_to_text(triplets) - - ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to) - relevant_events = await self._fake_graph_collect_events(ids) - - _ = await self._fake_vector_embed([query]) - vector_search_results = await self._fake_vector_search( - collection_name="Event_name", query_vector=[0.0], limit=0 - ) - top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results) - return self.descriptions_to_string(top_k_events) - - -# Test get_context fallback to triplets when no time is extracted -@pytest.mark.asyncio -async def test_fake_get_context_falls_back_to_triplets_when_no_time(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("no_time") - assert ctx == "edges->text" - assert tr._calls[0][0] == "get_triplets" - assert tr._calls[1][0] == "resolve_edges_to_text" - - -# Test get_context when time is extracted and vector ranking is applied -@pytest.mark.asyncio -async def test_fake_get_context_with_time_filters_and_vector_ranking(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("both time") - assert ctx.startswith("E2") - assert "#####################" in ctx - assert "E1" in ctx and "E3" not in ctx - - @pytest.fixture def mock_graph_engine(): """Create a mock graph engine.""" @@ -343,3 +259,342 @@ async def test_get_context_no_events_found(mock_graph_engine): assert context == "triplet text" mock_get_triplets.assert_awaited_once_with("test query") mock_resolve.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine): + """Test get_context with only time_from.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened after 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine): + """Test get_context with only time_to.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened before 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(): + """Test get_completion uses provided context.""" + retriever = TemporalRetriever() + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine): + """Test get_completion with session caching enabled.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "What happened in 2024?", session_id="test_session" + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine): + """Test get_completion with session config but no user ID.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_context_retrieved_but_empty(mock_graph_engine): + """Test get_completion when get_context returns empty string.""" + retriever = TemporalRetriever() + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + ) as mock_get_vector, + patch.object(retriever, "filter_top_k_events", return_value=[]), + ): + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + mock_get_vector.return_value = mock_vector_engine + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": ""}, + ] + } + ] + + with pytest.raises((UnboundLocalError, NameError)): + await retriever.get_completion("test query") + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "What happened in 2024?", response_model=TestModel + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel)