feat: adds additional unit tests for temporal retriever

This commit is contained in:
hajdul88 2025-12-10 17:57:40 +01:00
parent 1fef4c1ab3
commit ab6a1d1b5b

View file

@ -1,5 +1,6 @@
from types import SimpleNamespace
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
@ -222,3 +223,109 @@ async def test_fake_get_context_with_time_filters_and_vector_ranking():
assert ctx.startswith("E2")
assert "#####################" in ctx
assert "E1" in ctx and "E3" not in ctx
@pytest.fixture
def mock_graph_engine():
"""Create a mock graph engine."""
engine = AsyncMock()
engine.collect_time_ids = AsyncMock()
engine.collect_events = AsyncMock()
return engine
@pytest.fixture
def mock_vector_engine():
"""Create a mock vector engine."""
engine = AsyncMock()
engine.embedding_engine = AsyncMock()
engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
engine.search = AsyncMock()
return engine
@pytest.mark.asyncio
async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine):
"""Test get_context when time range is extracted from query."""
retriever = TemporalRetriever(top_k=5)
mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
{"id": "e2", "description": "Event 2"},
]
}
]
mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
with patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
), patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
), patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("What happened in 2024?")
assert isinstance(context, str)
assert len(context) > 0
assert "Event" in context
@pytest.mark.asyncio
async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine):
"""Test get_context falls back to triplets when no time is extracted."""
retriever = TemporalRetriever()
with patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
), patch.object(
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
) as mock_get_triplets, patch.object(
retriever, "resolve_edges_to_text", return_value="triplet text"
) as mock_resolve:
async def mock_extract_time(query):
return None, None
retriever.extract_time_from_query = mock_extract_time
context = await retriever.get_context("test query")
assert context == "triplet text"
mock_get_triplets.assert_awaited_once_with("test query")
mock_resolve.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_context_no_events_found(mock_graph_engine):
"""Test get_context falls back to triplets when no events are found."""
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = []
with patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
), patch.object(
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
) as mock_get_triplets, patch.object(
retriever, "resolve_edges_to_text", return_value="triplet text"
) as mock_resolve:
async def mock_extract_time(query):
return "2024-01-01", "2024-12-31"
retriever.extract_time_from_query = mock_extract_time
context = await retriever.get_context("test query")
assert context == "triplet text"
mock_get_triplets.assert_awaited_once_with("test query")
mock_resolve.assert_awaited_once()