feat: adds additional unit tests for temporal retriever
This commit is contained in:
parent
1fef4c1ab3
commit
ab6a1d1b5b
1 changed files with 107 additions and 0 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from types import SimpleNamespace
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
|
||||
|
|
@ -222,3 +223,109 @@ async def test_fake_get_context_with_time_filters_and_vector_ranking():
|
|||
assert ctx.startswith("E2")
|
||||
assert "#####################" in ctx
|
||||
assert "E1" in ctx and "E3" not in ctx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_engine():
|
||||
"""Create a mock graph engine."""
|
||||
engine = AsyncMock()
|
||||
engine.collect_time_ids = AsyncMock()
|
||||
engine.collect_events = AsyncMock()
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_engine():
|
||||
"""Create a mock vector engine."""
|
||||
engine = AsyncMock()
|
||||
engine.embedding_engine = AsyncMock()
|
||||
engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
engine.search = AsyncMock()
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine):
|
||||
"""Test get_context when time range is extracted from query."""
|
||||
retriever = TemporalRetriever(top_k=5)
|
||||
|
||||
mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"]
|
||||
mock_graph_engine.collect_events.return_value = [
|
||||
{
|
||||
"events": [
|
||||
{"id": "e1", "description": "Event 1"},
|
||||
{"id": "e2", "description": "Event 2"},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
|
||||
mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
|
||||
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
|
||||
|
||||
with patch.object(
|
||||
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
|
||||
), patch(
|
||||
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
), patch(
|
||||
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
context = await retriever.get_context("What happened in 2024?")
|
||||
|
||||
assert isinstance(context, str)
|
||||
assert len(context) > 0
|
||||
assert "Event" in context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine):
|
||||
"""Test get_context falls back to triplets when no time is extracted."""
|
||||
retriever = TemporalRetriever()
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
), patch.object(
|
||||
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
|
||||
) as mock_get_triplets, patch.object(
|
||||
retriever, "resolve_edges_to_text", return_value="triplet text"
|
||||
) as mock_resolve:
|
||||
async def mock_extract_time(query):
|
||||
return None, None
|
||||
|
||||
retriever.extract_time_from_query = mock_extract_time
|
||||
|
||||
context = await retriever.get_context("test query")
|
||||
|
||||
assert context == "triplet text"
|
||||
mock_get_triplets.assert_awaited_once_with("test query")
|
||||
mock_resolve.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_no_events_found(mock_graph_engine):
|
||||
"""Test get_context falls back to triplets when no events are found."""
|
||||
retriever = TemporalRetriever()
|
||||
|
||||
mock_graph_engine.collect_time_ids.return_value = []
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
), patch.object(
|
||||
retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
|
||||
) as mock_get_triplets, patch.object(
|
||||
retriever, "resolve_edges_to_text", return_value="triplet text"
|
||||
) as mock_resolve:
|
||||
async def mock_extract_time(query):
|
||||
return "2024-01-01", "2024-12-31"
|
||||
|
||||
retriever.extract_time_from_query = mock_extract_time
|
||||
|
||||
context = await retriever.get_context("test query")
|
||||
|
||||
assert context == "triplet text"
|
||||
mock_get_triplets.assert_awaited_once_with("test query")
|
||||
mock_resolve.assert_awaited_once()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue