adds 100% cov to temporal retriever
This commit is contained in:
parent
430df0db15
commit
c5c60ccad0
1 changed files with 79 additions and 0 deletions
|
|
@ -1,8 +1,12 @@
|
|||
from types import SimpleNamespace
|
||||
import pytest
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from datetime import datetime
|
||||
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
|
||||
|
||||
# Test TemporalRetriever initialization defaults and overrides
|
||||
|
|
@ -598,3 +602,78 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector
|
|||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
assert isinstance(completion[0], TestModel)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_time_from_query_relative_path():
|
||||
"""Test extract_time_from_query with relative prompt path."""
|
||||
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
|
||||
|
||||
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
|
||||
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
|
||||
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
|
||||
|
||||
with (
|
||||
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
|
||||
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
|
||||
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
|
||||
patch.object(
|
||||
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
|
||||
),
|
||||
):
|
||||
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
|
||||
|
||||
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
|
||||
|
||||
assert time_from == mock_timestamp_from
|
||||
assert time_to == mock_timestamp_to
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_time_from_query_absolute_path():
|
||||
"""Test extract_time_from_query with absolute prompt path."""
|
||||
retriever = TemporalRetriever(time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt")
|
||||
|
||||
mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
|
||||
mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
|
||||
mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)
|
||||
|
||||
with (
|
||||
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
|
||||
patch("cognee.modules.retrieval.temporal_retriever.os.path.dirname", return_value="/absolute/path/to"),
|
||||
patch("cognee.modules.retrieval.temporal_retriever.os.path.basename", return_value="extract_query_time.txt"),
|
||||
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
|
||||
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
|
||||
patch.object(
|
||||
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
|
||||
),
|
||||
):
|
||||
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
|
||||
|
||||
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")
|
||||
|
||||
assert time_from == mock_timestamp_from
|
||||
assert time_to == mock_timestamp_to
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_time_from_query_with_none_values():
|
||||
"""Test extract_time_from_query when interval has None values."""
|
||||
retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")
|
||||
|
||||
mock_interval = QueryInterval(starts_at=None, ends_at=None)
|
||||
|
||||
with (
|
||||
patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
|
||||
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
|
||||
patch("cognee.modules.retrieval.temporal_retriever.render_prompt", return_value="System prompt"),
|
||||
patch.object(
|
||||
LLMGateway, "acreate_structured_output", new_callable=AsyncMock, return_value=mock_interval
|
||||
),
|
||||
):
|
||||
mock_datetime.now.return_value.strftime.return_value = "11-12-2024"
|
||||
|
||||
time_from, time_to = await retriever.extract_time_from_query("What happened?")
|
||||
|
||||
assert time_from is None
|
||||
assert time_to is None
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue