test fix
This commit is contained in:
parent
16b073bf8c
commit
47cce90112
1 changed files with 36 additions and 5 deletions
|
|
@ -1,6 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from cognee.modules.users.methods import get_default_user
|
|
||||||
from cognee.context_global_variables import session_user
|
from cognee.context_global_variables import session_user
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -17,13 +16,20 @@ def create_mock_cache_engine(qa_history=None):
|
||||||
return mock_cache
|
return mock_cache
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_user():
|
||||||
|
"""Create a mock user without database access"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "test-user-id-123"
|
||||||
|
return mock_user
|
||||||
|
|
||||||
|
|
||||||
class TestConversationHistoryUtils:
|
class TestConversationHistoryUtils:
|
||||||
"""Test the two utility functions: get_conversation_history and save_to_session_cache"""
|
"""Test the two utility functions: get_conversation_history and save_to_session_cache"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_conversation_history_returns_empty_when_no_history(self):
|
async def test_get_conversation_history_returns_empty_when_no_history(self):
|
||||||
"""Test get_conversation_history returns empty string when no history exists."""
|
"""Test get_conversation_history returns empty string when no history exists."""
|
||||||
user = await get_default_user()
|
user = create_mock_user()
|
||||||
session_user.set(user)
|
session_user.set(user)
|
||||||
|
|
||||||
mock_cache = create_mock_cache_engine([])
|
mock_cache = create_mock_cache_engine([])
|
||||||
|
|
@ -41,7 +47,7 @@ class TestConversationHistoryUtils:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_conversation_history_formats_history_correctly(self):
|
async def test_get_conversation_history_formats_history_correctly(self):
|
||||||
"""Test get_conversation_history formats Q&A history with correct structure."""
|
"""Test get_conversation_history formats Q&A history with correct structure."""
|
||||||
user = await get_default_user()
|
user = create_mock_user()
|
||||||
session_user.set(user)
|
session_user.set(user)
|
||||||
|
|
||||||
mock_history = [
|
mock_history = [
|
||||||
|
|
@ -71,7 +77,7 @@ class TestConversationHistoryUtils:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_save_to_session_cache_saves_correctly(self):
|
async def test_save_to_session_cache_saves_correctly(self):
|
||||||
"""Test save_to_session_cache calls add_qa with correct parameters."""
|
"""Test save_to_session_cache calls add_qa with correct parameters."""
|
||||||
user = await get_default_user()
|
user = create_mock_user()
|
||||||
session_user.set(user)
|
session_user.set(user)
|
||||||
|
|
||||||
mock_cache = create_mock_cache_engine([])
|
mock_cache = create_mock_cache_engine([])
|
||||||
|
|
@ -97,3 +103,28 @@ class TestConversationHistoryUtils:
|
||||||
assert call_kwargs["context"] == "Python is a programming language"
|
assert call_kwargs["context"] == "Python is a programming language"
|
||||||
assert call_kwargs["answer"] == "Python is a high-level programming language"
|
assert call_kwargs["answer"] == "Python is a high-level programming language"
|
||||||
assert call_kwargs["session_id"] == "my_session"
|
assert call_kwargs["session_id"] == "my_session"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_to_session_cache_uses_default_session_when_none(self):
|
||||||
|
"""Test save_to_session_cache uses 'default_session' when session_id is None."""
|
||||||
|
user = create_mock_user()
|
||||||
|
session_user.set(user)
|
||||||
|
|
||||||
|
mock_cache = create_mock_cache_engine([])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"cognee.infrastructure.databases.cache.get_cache_engine.get_cache_engine",
|
||||||
|
return_value=mock_cache,
|
||||||
|
):
|
||||||
|
from cognee.modules.retrieval.utils.session_cache import save_to_session_cache
|
||||||
|
|
||||||
|
result = await save_to_session_cache(
|
||||||
|
query="Test question",
|
||||||
|
context_summary="Test context",
|
||||||
|
answer="Test answer",
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
call_kwargs = mock_cache.add_qa.call_args.kwargs
|
||||||
|
assert call_kwargs["session_id"] == "default_session"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue