feat: increasing the coverage of summaries retriever
This commit is contained in:
parent
21a84d3100
commit
ee00af9266
1 changed files with 73 additions and 0 deletions
|
|
@ -118,3 +118,76 @@ async def test_get_completion_without_context(mock_vector_engine):
|
|||
|
||||
assert len(completion) == 1
|
||||
assert completion[0]["text"] == "S.R."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_defaults():
|
||||
"""Test SummariesRetriever initialization with defaults."""
|
||||
retriever = SummariesRetriever()
|
||||
|
||||
assert retriever.top_k == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_custom_top_k():
|
||||
"""Test SummariesRetriever initialization with custom top_k."""
|
||||
retriever = SummariesRetriever(top_k=10)
|
||||
|
||||
assert retriever.top_k == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_empty_payload(mock_vector_engine):
|
||||
"""Test get_context handles empty payload."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.payload = {}
|
||||
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
retriever = SummariesRetriever()
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
context = await retriever.get_context("test query")
|
||||
|
||||
assert len(context) == 1
|
||||
assert context[0] == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_with_session_id(mock_vector_engine):
|
||||
"""Test get_completion with session_id parameter."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.payload = {"text": "S.R."}
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
retriever = SummariesRetriever()
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
completion = await retriever.get_completion("test query", session_id="test_session")
|
||||
|
||||
assert len(completion) == 1
|
||||
assert completion[0]["text"] == "S.R."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_with_kwargs(mock_vector_engine):
|
||||
"""Test get_completion accepts additional kwargs."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.payload = {"text": "S.R."}
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
retriever = SummariesRetriever()
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.summaries_retriever.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
completion = await retriever.get_completion("test query", extra_param="value")
|
||||
|
||||
assert len(completion) == 1
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue