adds 100% cov to temporal retriever

This commit is contained in:
hajdul88 2025-12-11 12:27:14 +01:00
parent 430df0db15
commit c5c60ccad0

View file

@ -1,8 +1,12 @@
from types import SimpleNamespace from types import SimpleNamespace
import pytest import pytest
import os
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
from datetime import datetime
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
from cognee.infrastructure.llm import LLMGateway
# Test TemporalRetriever initialization defaults and overrides # Test TemporalRetriever initialization defaults and overrides
@ -598,3 +602,78 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector
assert isinstance(completion, list) assert isinstance(completion, list)
assert len(completion) == 1 assert len(completion) == 1
assert isinstance(completion[0], TestModel) assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_extract_time_from_query_relative_path():
"""Test extract_time_from_query with relative prompt path."""
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
with (
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
patch.object(
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
),
):
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
assert time_from == mock_timestamp_from
assert time_to == mock_timestamp_to
@pytest.mark.asyncio
async def test_extract_time_from_query_absolute_path():
"""Test extract_time_from_query with absolute prompt path."""
retriever = TemporalRetriever(time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt")
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
with (
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
patch("cognee.modules.retrieval.temporal_retriever.os.path.dirname", return_value="/absolute/path/to"),
patch("cognee.modules.retrieval.temporal_retriever.os.path.basename", return_value="extract_query_time.txt"),
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
patch.object(
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
),
):
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
assert time_from == mock_timestamp_from
assert time_to == mock_timestamp_to
@pytest.mark.asyncio
async def test_extract_time_from_query_with_none_values():
"""Test extract_time_from_query when interval has None values."""
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
mock_interval = QueryInterval(starts_at=None, ends_at=None)
with (
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
patch.object(
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
),
):
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
time_from, time_to = await retriever.extract_time_from_query("What happened?")
assert time_from is None
assert time_to is None