adds 100% cov to temporal retriever
This commit is contained in:
parent
430df0db15
commit
c5c60ccad0
1 changed files with 79 additions and 0 deletions
|
|
@ -1,8 +1,12 @@
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
import pytest
|
import pytest
|
||||||
|
import os
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||||
|
from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
|
||||||
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
|
||||||
|
|
||||||
# Test TemporalRetriever initialization defaults and overrides
|
# Test TemporalRetriever initialization defaults and overrides
|
||||||
|
|
@ -598,3 +602,78 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector
|
||||||
assert isinstance(completion, list)
|
assert isinstance(completion, list)
|
||||||
assert len(completion) == 1
|
assert len(completion) == 1
|
||||||
assert isinstance(completion[0], TestModel)
|
assert isinstance(completion[0], TestModel)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_time_from_query_relative_path():
|
||||||
|
"""Test extract_time_from_query with relative prompt path."""
|
||||||
|
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
|
||||||
|
|
||||||
|
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
|
||||||
|
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
|
||||||
|
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
|
||||||
|
patch.object(
|
||||||
|
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
|
||||||
|
),
|
||||||
|
):
|
||||||
|
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
|
||||||
|
|
||||||
|
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
|
||||||
|
|
||||||
|
assert time_from == mock_timestamp_from
|
||||||
|
assert time_to == mock_timestamp_to
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_time_from_query_absolute_path():
|
||||||
|
"""Test extract_time_from_query with absolute prompt path."""
|
||||||
|
retriever = TemporalRetriever(time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt")
|
||||||
|
|
||||||
|
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
|
||||||
|
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
|
||||||
|
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.os.path.dirname", return_value="/absolute/path/to"),
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.os.path.basename", return_value="extract_query_time.txt"),
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
|
||||||
|
patch.object(
|
||||||
|
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
|
||||||
|
),
|
||||||
|
):
|
||||||
|
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
|
||||||
|
|
||||||
|
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
|
||||||
|
|
||||||
|
assert time_from == mock_timestamp_from
|
||||||
|
assert time_to == mock_timestamp_to
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_time_from_query_with_none_values():
|
||||||
|
"""Test extract_time_from_query when interval has None values."""
|
||||||
|
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
|
||||||
|
|
||||||
|
mock_interval = QueryInterval(starts_at=None, ends_at=None)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
|
||||||
|
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
|
||||||
|
patch.object(
|
||||||
|
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
|
||||||
|
),
|
||||||
|
):
|
||||||
|
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
|
||||||
|
|
||||||
|
time_from, time_to = await retriever.extract_time_from_query("What happened?")
|
||||||
|
|
||||||
|
assert time_from is None
|
||||||
|
assert time_to is None
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue