adds 100% cov to temporal retriever

This commit is contained in:
hajdul88 2025-12-11 12:27:14 +01:00
parent 430df0db15
commit c5c60ccad0

View file

@ -1,8 +1,12 @@
from types import SimpleNamespace
import pytest
import os
from unittest.mock import AsyncMock, patch, MagicMock
from datetime import datetime
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
from cognee.infrastructure.llm import LLMGateway
# Test TemporalRetriever initialization defaults and overrides
@ -598,3 +602,78 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)
@pytest.mark.asyncio
async def test_extract_time_from_query_relative_path():
"""Test extract_time_from_query with relative prompt path."""
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
with (
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
patch.object(
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
),
):
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
assert time_from == mock_timestamp_from
assert time_to == mock_timestamp_to
@pytest.mark.asyncio
async def test_extract_time_from_query_absolute_path():
"""Test extract_time_from_query with absolute prompt path."""
retriever = TemporalRetriever(time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt")
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
with (
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
patch("cognee.modules.retrieval.temporal_retriever.os.path.dirname", return_value="/absolute/path/to"),
patch("cognee.modules.retrieval.temporal_retriever.os.path.basename", return_value="extract_query_time.txt"),
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
patch.object(
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
),
):
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
assert time_from == mock_timestamp_from
assert time_to == mock_timestamp_to
@pytest.mark.asyncio
async def test_extract_time_from_query_with_none_values():
"""Test extract_time_from_query when interval has None values."""
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
mock_interval = QueryInterval(starts_at=None, ends_at=None)
with (
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
patch.object(
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
),
):
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
time_from, time_to = await retriever.extract_time_from_query("What happened?")
assert time_from is None
assert time_to is None