feat: increasing the coverage of summaries retriever

This commit is contained in:
hajdul88 2025-12-10 19:04:28 +01:00
parent 21a84d3100
commit ee00af9266

View file

@ -118,3 +118,76 @@ async def test_get_completion_without_context(mock_vector_engine):
assert len(completion) == 1
assert completion[0]["text"] == "S.R."
@pytest.mark.asyncio
async def test_init_defaults():
"""Test SummariesRetriever initialization with defaults."""
retriever = SummariesRetriever()
assert retriever.top_k == 5
@pytest.mark.asyncio
async def test_init_custom_top_k():
"""Test SummariesRetriever initialization with custom top_k."""
retriever = SummariesRetriever(top_k=10)
assert retriever.top_k == 10
@pytest.mark.asyncio
async def test_get_context_empty_payload(mock_vector_engine):
"""Test get_context handles empty payload."""
mock_result = MagicMock()
mock_result.payload = {}
mock_vector_engine.search.return_value = [mock_result]
retriever = SummariesRetriever()
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert len(context) == 1
assert context[0] == {}
@pytest.mark.asyncio
async def test_get_completion_with_session_id(mock_vector_engine):
"""Test get_completion with session_id parameter."""
mock_result = MagicMock()
mock_result.payload = {"text": "S.R."}
mock_vector_engine.search.return_value = [mock_result]
retriever = SummariesRetriever()
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
completion = await retriever.get_completion("test query", session_id="test_session")
assert len(completion) == 1
assert completion[0]["text"] == "S.R."
@pytest.mark.asyncio
async def test_get_completion_with_kwargs(mock_vector_engine):
"""Test get_completion accepts additional kwargs."""
mock_result = MagicMock()
mock_result.payload = {"text": "S.R."}
mock_vector_engine.search.return_value = [mock_result]
retriever = SummariesRetriever()
with patch(
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
completion = await retriever.get_completion("test query", extra_param="value")
assert len(completion) == 1