feat: adds unit test to triplet retriever test without session and with provided context
This commit is contained in:
parent
ac58058eaa
commit
0329777290
1 changed files with 51 additions and 0 deletions
|
|
@ -81,3 +81,54 @@ async def test_get_context_collection_not_found_error(mock_vector_engine):
|
|||
):
|
||||
with pytest.raises(NoDataError, match="No data found"):
|
||||
await retriever.get_context("test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_without_session(mock_vector_engine):
|
||||
"""Test get_completion without session caching."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.payload = {"text": "Alice knows Bob"}
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
retriever = TripletRetriever()
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
), patch(
|
||||
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
), patch(
|
||||
"cognee.modules.retrieval.triplet_retriever.CacheConfig"
|
||||
) as mock_cache_config:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion("test query")
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_completion_with_provided_context(mock_vector_engine):
|
||||
"""Test get_completion with provided context."""
|
||||
retriever = TripletRetriever()
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
||||
return_value="Generated answer",
|
||||
), patch(
|
||||
"cognee.modules.retrieval.triplet_retriever.CacheConfig"
|
||||
) as mock_cache_config:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
mock_cache_config.return_value = mock_config
|
||||
|
||||
completion = await retriever.get_completion("test query", context="Provided context")
|
||||
|
||||
assert isinstance(completion, list)
|
||||
assert len(completion) == 1
|
||||
assert completion[0] == "Generated answer"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue