feat: increasing coverage of temporal retriever

This commit is contained in:
hajdul88 2025-12-10 18:58:13 +01:00
parent 29f94e39b9
commit 21a84d3100

View file

@ -141,90 +141,6 @@ async def test_filter_top_k_events_error_handling():
await tr.filter_top_k_events([{}], [])
class _FakeRetriever(TemporalRetriever):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._calls = []
async def extract_time_from_query(self, query: str):
if "both" in query:
return "2024-01-01", "2024-12-31"
if "from_only" in query:
return "2024-01-01", None
if "to_only" in query:
return None, "2024-12-31"
return None, None
async def get_triplets(self, query: str):
self._calls.append(("get_triplets", query))
return [{"s": "a", "p": "b", "o": "c"}]
async def resolve_edges_to_text(self, triplets):
self._calls.append(("resolve_edges_to_text", len(triplets)))
return "edges->text"
async def _fake_graph_collect_ids(self, **kwargs):
return ["e1", "e2"]
async def _fake_graph_collect_events(self, ids):
return [
{
"events": [
{"id": "e1", "description": "E1"},
{"id": "e2", "description": "E2"},
{"id": "e3", "description": "E3"},
]
}
]
async def _fake_vector_embed(self, texts):
assert isinstance(texts, list) and texts
return [[0.0, 1.0, 2.0]]
async def _fake_vector_search(self, **kwargs):
return [
SimpleNamespace(payload={"id": "e2"}, score=0.05),
SimpleNamespace(payload={"id": "e1"}, score=0.10),
]
async def get_context(self, query: str):
time_from, time_to = await self.extract_time_from_query(query)
if not (time_from or time_to):
triplets = await self.get_triplets(query)
return await self.resolve_edges_to_text(triplets)
ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to)
relevant_events = await self._fake_graph_collect_events(ids)
_ = await self._fake_vector_embed([query])
vector_search_results = await self._fake_vector_search(
collection_name="Event_name", query_vector=[0.0], limit=0
)
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
return self.descriptions_to_string(top_k_events)
# Test get_context fallback to triplets when no time is extracted
@pytest.mark.asyncio
async def test_fake_get_context_falls_back_to_triplets_when_no_time():
tr = _FakeRetriever(top_k=2)
ctx = await tr.get_context("no_time")
assert ctx == "edges->text"
assert tr._calls[0][0] == "get_triplets"
assert tr._calls[1][0] == "resolve_edges_to_text"
# Test get_context when time is extracted and vector ranking is applied
@pytest.mark.asyncio
async def test_fake_get_context_with_time_filters_and_vector_ranking():
tr = _FakeRetriever(top_k=2)
ctx = await tr.get_context("both time")
assert ctx.startswith("E2")
assert "#####################" in ctx
assert "E1" in ctx and "E3" not in ctx
@pytest.fixture
def mock_graph_engine():
"""Create a mock graph engine."""
@ -343,3 +259,342 @@ async def test_get_context_no_events_found(mock_graph_engine):
assert context == "triplet text"
mock_get_triplets.assert_awaited_once_with("test query")
mock_resolve.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine):
"""Test get_context with only time_from."""
retriever = TemporalRetriever(top_k=5)
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
):
context = await retriever.get_context("What happened after 2024?")
assert isinstance(context, str)
assert "Event 1" in context
@pytest.mark.asyncio
async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
"""Test get_context with only time_to."""
retriever = TemporalRetriever(top_k=5)
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
):
context = await retriever.get_context("What happened before 2024?")
assert isinstance(context, str)
assert "Event 1" in context
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine):
"""Test get_completion retrieves context when not provided."""
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("What happened in 2024?")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_provided_context():
"""Test get_completion uses provided context."""
retriever = TemporalRetriever()
with (
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion("test query", context="Provided context")
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine):
"""Test get_completion with session caching enabled."""
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
mock_user = MagicMock()
mock_user.id = "test-user-id"
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_conversation_history",
return_value="Previous conversation",
),
patch(
"cognee.modules.retrieval.temporal_retriever.summarize_text",
return_value="Context summary",
),
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value="Generated answer",
),
patch(
"cognee.modules.retrieval.temporal_retriever.save_conversation_history",
) as mock_save,
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = mock_user
completion = await retriever.get_completion(
"What happened in 2024?", session_id="test_session"
)
assert isinstance(completion, list)
assert len(completion) == 1
assert completion[0] == "Generated answer"
mock_save.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine):
"""Test get_completion with session config but no user ID."""
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value="Generated answer",
),
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user,
):
mock_config = MagicMock()
mock_config.caching = True
mock_cache_config.return_value = mock_config
mock_session_user.get.return_value = None # No user
completion = await retriever.get_completion("What happened in 2024?")
assert isinstance(completion, list)
assert len(completion) == 1
@pytest.mark.asyncio
async def test_get_completion_context_retrieved_but_empty(mock_graph_engine):
"""Test get_completion when get_context returns empty string."""
retriever = TemporalRetriever()
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
) as mock_get_vector,
patch.object(retriever, "filter_top_k_events", return_value=[]),
):
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
mock_get_vector.return_value = mock_vector_engine
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": ""},
]
}
]
with pytest.raises((UnboundLocalError, NameError)):
await retriever.get_completion("test query")
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine):
"""Test get_completion with custom response model."""
from pydantic import BaseModel
class TestModel(BaseModel):
answer: str
retriever = TemporalRetriever()
mock_graph_engine.collect_time_ids.return_value = ["e1"]
mock_graph_engine.collect_events.return_value = [
{
"events": [
{"id": "e1", "description": "Event 1"},
]
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
patch.object(
retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_graph_engine",
return_value=mock_graph_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.temporal_retriever.generate_completion",
return_value=TestModel(answer="Test answer"),
),
patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
):
mock_config = MagicMock()
mock_config.caching = False
mock_cache_config.return_value = mock_config
completion = await retriever.get_completion(
"What happened in 2024?", response_model=TestModel
)
assert isinstance(completion, list)
assert len(completion) == 1
assert isinstance(completion[0], TestModel)