feat: increasing coverage in triplet retriever
This commit is contained in:
parent
9d9a388804
commit
29f94e39b9
1 changed files with 196 additions and 4 deletions
|
|
@ -84,10 +84,75 @@ async def test_get_context_collection_not_found_error(mock_vector_engine):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_completion_without_session(mock_vector_engine):
|
async def test_get_context_empty_payload_text(mock_vector_engine):
|
||||||
"""Test get_completion without session caching."""
|
"""Test get_context handles missing text in payload."""
|
||||||
mock_result = MagicMock()
|
mock_result = MagicMock()
|
||||||
mock_result.payload = {"text": "Alice knows Bob"}
|
mock_result.payload = {}
|
||||||
|
|
||||||
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
|
retriever = TripletRetriever()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
||||||
|
return_value=mock_vector_engine,
|
||||||
|
):
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await retriever.get_context("test query")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_context_single_triplet(mock_vector_engine):
|
||||||
|
"""Test get_context with single triplet result."""
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.payload = {"text": "Single triplet"}
|
||||||
|
|
||||||
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
|
retriever = TripletRetriever()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
||||||
|
return_value=mock_vector_engine,
|
||||||
|
):
|
||||||
|
context = await retriever.get_context("test query")
|
||||||
|
|
||||||
|
assert context == "Single triplet"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_defaults():
|
||||||
|
"""Test TripletRetriever initialization with defaults."""
|
||||||
|
retriever = TripletRetriever()
|
||||||
|
|
||||||
|
assert retriever.user_prompt_path == "context_for_question.txt"
|
||||||
|
assert retriever.system_prompt_path == "answer_simple_question.txt"
|
||||||
|
assert retriever.top_k == 5 # Default is 5
|
||||||
|
assert retriever.system_prompt is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_custom_params():
|
||||||
|
"""Test TripletRetriever initialization with custom parameters."""
|
||||||
|
retriever = TripletRetriever(
|
||||||
|
user_prompt_path="custom_user.txt",
|
||||||
|
system_prompt_path="custom_system.txt",
|
||||||
|
system_prompt="Custom prompt",
|
||||||
|
top_k=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert retriever.user_prompt_path == "custom_user.txt"
|
||||||
|
assert retriever.system_prompt_path == "custom_system.txt"
|
||||||
|
assert retriever.system_prompt == "Custom prompt"
|
||||||
|
assert retriever.top_k == 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_without_context(mock_vector_engine):
|
||||||
|
"""Test get_completion retrieves context when not provided."""
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.payload = {"text": "Test triplet"}
|
||||||
|
mock_vector_engine.has_collection.return_value = True
|
||||||
mock_vector_engine.search.return_value = [mock_result]
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
retriever = TripletRetriever()
|
retriever = TripletRetriever()
|
||||||
|
|
@ -116,7 +181,7 @@ async def test_get_completion_without_session(mock_vector_engine):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_completion_with_provided_context(mock_vector_engine):
|
async def test_get_completion_with_provided_context(mock_vector_engine):
|
||||||
"""Test get_completion with provided context."""
|
"""Test get_completion uses provided context."""
|
||||||
retriever = TripletRetriever()
|
retriever = TripletRetriever()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -135,3 +200,130 @@ async def test_get_completion_with_provided_context(mock_vector_engine):
|
||||||
assert isinstance(completion, list)
|
assert isinstance(completion, list)
|
||||||
assert len(completion) == 1
|
assert len(completion) == 1
|
||||||
assert completion[0] == "Generated answer"
|
assert completion[0] == "Generated answer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_with_session(mock_vector_engine):
|
||||||
|
"""Test get_completion with session caching enabled."""
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.payload = {"text": "Test triplet"}
|
||||||
|
mock_vector_engine.has_collection.return_value = True
|
||||||
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
|
retriever = TripletRetriever()
|
||||||
|
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "test-user-id"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
||||||
|
return_value=mock_vector_engine,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.get_conversation_history",
|
||||||
|
return_value="Previous conversation",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.summarize_text",
|
||||||
|
return_value="Context summary",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.save_conversation_history",
|
||||||
|
) as mock_save,
|
||||||
|
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
|
||||||
|
patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = True
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
mock_session_user.get.return_value = mock_user
|
||||||
|
|
||||||
|
completion = await retriever.get_completion("test query", session_id="test_session")
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
assert completion[0] == "Generated answer"
|
||||||
|
mock_save.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_with_session_no_user_id(mock_vector_engine):
|
||||||
|
"""Test get_completion with session config but no user ID."""
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.payload = {"text": "Test triplet"}
|
||||||
|
mock_vector_engine.has_collection.return_value = True
|
||||||
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
|
retriever = TripletRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
||||||
|
return_value=mock_vector_engine,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
||||||
|
return_value="Generated answer",
|
||||||
|
),
|
||||||
|
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
|
||||||
|
patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = True
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
mock_session_user.get.return_value = None # No user
|
||||||
|
|
||||||
|
completion = await retriever.get_completion("test query")
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_completion_with_response_model(mock_vector_engine):
|
||||||
|
"""Test get_completion with custom response model."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class TestModel(BaseModel):
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.payload = {"text": "Test triplet"}
|
||||||
|
mock_vector_engine.has_collection.return_value = True
|
||||||
|
mock_vector_engine.search.return_value = [mock_result]
|
||||||
|
|
||||||
|
retriever = TripletRetriever()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
|
||||||
|
return_value=mock_vector_engine,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"cognee.modules.retrieval.triplet_retriever.generate_completion",
|
||||||
|
return_value=TestModel(answer="Test answer"),
|
||||||
|
),
|
||||||
|
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config,
|
||||||
|
):
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.caching = False
|
||||||
|
mock_cache_config.return_value = mock_config
|
||||||
|
|
||||||
|
completion = await retriever.get_completion("test query", response_model=TestModel)
|
||||||
|
|
||||||
|
assert isinstance(completion, list)
|
||||||
|
assert len(completion) == 1
|
||||||
|
assert isinstance(completion[0], TestModel)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_none_top_k():
|
||||||
|
"""Test TripletRetriever initialization with None top_k."""
|
||||||
|
retriever = TripletRetriever(top_k=None)
|
||||||
|
|
||||||
|
assert retriever.top_k == 5
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue