diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index c3ba2c864..03433c379 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -1,8 +1,12 @@ from types import SimpleNamespace import pytest +import os from unittest.mock import AsyncMock, patch, MagicMock +from datetime import datetime from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp +from cognee.infrastructure.llm import LLMGateway # Test TemporalRetriever initialization defaults and overrides @@ -598,3 +602,78 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector assert isinstance(completion, list) assert len(completion) == 1 assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_extract_time_from_query_relative_path(): + """Test extract_time_from_query with relative prompt path.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"), + patch.object( + LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to + + +@pytest.mark.asyncio +async def test_extract_time_from_query_absolute_path(): + """Test extract_time_from_query with absolute prompt path.""" + retriever = TemporalRetriever(time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt") + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True), + patch("cognee.modules.retrieval.temporal_retriever.os.path.dirname", return_value="/absolute/path/to"), + patch("cognee.modules.retrieval.temporal_retriever.os.path.basename", return_value="extract_query_time.txt"), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"), + patch.object( + LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to + + +@pytest.mark.asyncio +async def test_extract_time_from_query_with_none_values(): + """Test extract_time_from_query when interval has None values.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_interval = QueryInterval(starts_at=None, ends_at=None) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"), + patch.object( + LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened?") + + assert time_from is None + assert time_to is None