Merge branch 'dev' into feature/cog-3213-docs-set-up-guide-script-tests
This commit is contained in:
commit
929d88557e
13 changed files with 5657 additions and 0 deletions
183
cognee/tests/unit/modules/retrieval/chunks_retriever_test.py
Normal file
183
cognee/tests/unit/modules/retrieval/chunks_retriever_test.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vector_engine():
    """Provide an AsyncMock standing in for the vector engine."""
    vector_engine = AsyncMock()
    vector_engine.search = AsyncMock()
    return vector_engine


@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
    """get_context returns the payload of every search hit, in order."""
    first_hit, second_hit = MagicMock(), MagicMock()
    first_hit.payload = {"text": "Steve Rodger", "chunk_index": 0}
    second_hit.payload = {"text": "Mike Broski", "chunk_index": 1}
    mock_vector_engine.search.return_value = [first_hit, second_hit]

    retriever = ChunksRetriever(top_k=5)

    with patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    assert len(context) == 2
    assert context[0]["text"] == "Steve Rodger"
    assert context[1]["text"] == "Mike Broski"
    mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
    """A missing vector collection surfaces to callers as NoDataError."""
    mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")

    retriever = ChunksRetriever()

    engine_patch = patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch, pytest.raises(NoDataError, match="No data found"):
        await retriever.get_context("test query")


@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
    """No search hits means an empty context list."""
    mock_vector_engine.search.return_value = []

    retriever = ChunksRetriever()

    with patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        assert await retriever.get_context("test query") == []


@pytest.mark.asyncio
async def test_get_context_top_k_limit(mock_vector_engine):
    """top_k is forwarded to the vector search as the result limit."""
    hits = []
    for index in range(3):
        hit = MagicMock()
        hit.payload = {"text": f"Chunk {index}"}
        hits.append(hit)
    mock_vector_engine.search.return_value = hits

    retriever = ChunksRetriever(top_k=3)

    with patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    assert len(context) == 3
    mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_context(mock_vector_engine):
    """When a context is supplied, get_completion returns it untouched."""
    retriever = ChunksRetriever()
    given_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}]

    completion = await retriever.get_completion("test query", context=given_context)

    assert completion == given_context


@pytest.mark.asyncio
async def test_get_completion_without_context(mock_vector_engine):
    """Without a context argument, get_completion performs the retrieval itself."""
    hit = MagicMock()
    hit.payload = {"text": "Steve Rodger"}
    mock_vector_engine.search.return_value = [hit]

    retriever = ChunksRetriever()

    with patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        completion = await retriever.get_completion("test query")

    assert len(completion) == 1
    assert completion[0]["text"] == "Steve Rodger"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_defaults():
    """The default top_k is 5."""
    assert ChunksRetriever().top_k == 5


@pytest.mark.asyncio
async def test_init_custom_top_k():
    """A custom top_k is stored verbatim."""
    assert ChunksRetriever(top_k=10).top_k == 10


@pytest.mark.asyncio
async def test_init_none_top_k():
    """An explicit top_k=None is preserved."""
    assert ChunksRetriever(top_k=None).top_k is None


@pytest.mark.asyncio
async def test_get_context_empty_payload(mock_vector_engine):
    """A hit with an empty payload is passed through as an empty dict."""
    hit = MagicMock()
    hit.payload = {}
    mock_vector_engine.search.return_value = [hit]

    retriever = ChunksRetriever()

    with patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    # Exactly one entry, and it is the empty payload itself.
    assert context == [{}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session_id(mock_vector_engine):
    """get_completion accepts a session_id keyword and still returns payloads."""
    hit = MagicMock()
    hit.payload = {"text": "Steve Rodger"}
    mock_vector_engine.search.return_value = [hit]

    retriever = ChunksRetriever()

    with patch(
        "cognee.modules.retrieval.chunks_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        completion = await retriever.get_completion("test query", session_id="test_session")

    assert len(completion) == 1
    assert completion[0]["text"] == "Steve Rodger"
|
||||
492
cognee/tests/unit/modules/retrieval/conversation_history_test.py
Normal file
492
cognee/tests/unit/modules/retrieval/conversation_history_test.py
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from cognee.context_global_variables import session_user
|
||||
import importlib
|
||||
|
||||
|
||||
def create_mock_cache_engine(qa_history=None):
    """Build an AsyncMock cache engine preloaded with *qa_history* (default: empty list)."""
    cache_engine = AsyncMock()
    cache_engine.get_latest_qa = AsyncMock(
        return_value=[] if qa_history is None else qa_history
    )
    cache_engine.add_qa = AsyncMock(return_value=None)
    return cache_engine
|
||||
|
||||
|
||||
def create_mock_user():
    """Return a MagicMock user carrying the fixed id the tests expect."""
    user = MagicMock()
    user.id = "test-user-id-123"
    return user
|
||||
|
||||
|
||||
class TestConversationHistoryUtils:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_returns_empty_when_no_history(self):
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
mock_cache = create_mock_cache_engine([])
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
from cognee.modules.retrieval.utils.session_cache import get_conversation_history
|
||||
|
||||
result = await get_conversation_history(session_id="test_session")
|
||||
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_formats_history_correctly(self):
|
||||
"""Test get_conversation_history formats Q&A history with correct structure."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
mock_history = [
|
||||
{
|
||||
"time": "2024-01-15 10:30:45",
|
||||
"question": "What is AI?",
|
||||
"context": "AI is artificial intelligence",
|
||||
"answer": "AI stands for Artificial Intelligence",
|
||||
}
|
||||
]
|
||||
mock_cache = create_mock_cache_engine(mock_history)
|
||||
|
||||
# Import the real module to patch safely
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
get_conversation_history,
|
||||
)
|
||||
|
||||
result = await get_conversation_history(session_id="test_session")
|
||||
|
||||
assert "Previous conversation:" in result
|
||||
assert "[2024-01-15 10:30:45]" in result
|
||||
assert "QUESTION: What is AI?" in result
|
||||
assert "CONTEXT: AI is artificial intelligence" in result
|
||||
assert "ANSWER: AI stands for Artificial Intelligence" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_to_session_cache_saves_correctly(self):
|
||||
"""Test save_conversation_history calls add_qa with correct parameters."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
mock_cache = create_mock_cache_engine([])
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
)
|
||||
|
||||
result = await save_conversation_history(
|
||||
query="What is Python?",
|
||||
context_summary="Python is a programming language",
|
||||
answer="Python is a high-level programming language",
|
||||
session_id="my_session",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_cache.add_qa.assert_called_once()
|
||||
|
||||
call_kwargs = mock_cache.add_qa.call_args.kwargs
|
||||
assert call_kwargs["question"] == "What is Python?"
|
||||
assert call_kwargs["context"] == "Python is a programming language"
|
||||
assert call_kwargs["answer"] == "Python is a high-level programming language"
|
||||
assert call_kwargs["session_id"] == "my_session"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_to_session_cache_uses_default_session_when_none(self):
|
||||
"""Test save_conversation_history uses 'default_session' when session_id is None."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
mock_cache = create_mock_cache_engine([])
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
)
|
||||
|
||||
result = await save_conversation_history(
|
||||
query="Test question",
|
||||
context_summary="Test context",
|
||||
answer="Test answer",
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
call_kwargs = mock_cache.add_qa.call_args.kwargs
|
||||
assert call_kwargs["session_id"] == "default_session"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_conversation_history_no_user_id(self):
|
||||
"""Test save_conversation_history returns False when user_id is None."""
|
||||
session_user.set(None)
|
||||
|
||||
with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
)
|
||||
|
||||
result = await save_conversation_history(
|
||||
query="Test question",
|
||||
context_summary="Test context",
|
||||
answer="Test answer",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_conversation_history_caching_disabled(self):
|
||||
"""Test save_conversation_history returns False when caching is disabled."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
)
|
||||
|
||||
result = await save_conversation_history(
|
||||
query="Test question",
|
||||
context_summary="Test context",
|
||||
answer="Test answer",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_conversation_history_cache_engine_none(self):
|
||||
"""Test save_conversation_history returns False when cache_engine is None."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=None):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
)
|
||||
|
||||
result = await save_conversation_history(
|
||||
query="Test question",
|
||||
context_summary="Test context",
|
||||
answer="Test answer",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_conversation_history_cache_connection_error(self):
|
||||
"""Test save_conversation_history handles CacheConnectionError gracefully."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import CacheConnectionError
|
||||
|
||||
mock_cache = create_mock_cache_engine([])
|
||||
mock_cache.add_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed"))
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
)
|
||||
|
||||
result = await save_conversation_history(
|
||||
query="Test question",
|
||||
context_summary="Test context",
|
||||
answer="Test answer",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_conversation_history_generic_exception(self):
|
||||
"""Test save_conversation_history handles generic exceptions gracefully."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
mock_cache = create_mock_cache_engine([])
|
||||
mock_cache.add_qa = AsyncMock(side_effect=ValueError("Unexpected error"))
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
save_conversation_history,
|
||||
)
|
||||
|
||||
result = await save_conversation_history(
|
||||
query="Test question",
|
||||
context_summary="Test context",
|
||||
answer="Test answer",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_no_user_id(self):
|
||||
"""Test get_conversation_history returns empty string when user_id is None."""
|
||||
session_user.set(None)
|
||||
|
||||
with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
get_conversation_history,
|
||||
)
|
||||
|
||||
result = await get_conversation_history(session_id="test_session")
|
||||
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_caching_disabled(self):
|
||||
"""Test get_conversation_history returns empty string when caching is disabled."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = False
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
get_conversation_history,
|
||||
)
|
||||
|
||||
result = await get_conversation_history(session_id="test_session")
|
||||
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_default_session(self):
|
||||
"""Test get_conversation_history uses 'default_session' when session_id is None."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
mock_cache = create_mock_cache_engine([])
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
get_conversation_history,
|
||||
)
|
||||
|
||||
await get_conversation_history(session_id=None)
|
||||
|
||||
mock_cache.get_latest_qa.assert_called_once_with(str(user.id), "default_session")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_cache_engine_none(self):
|
||||
"""Test get_conversation_history returns empty string when cache_engine is None."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=None):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
get_conversation_history,
|
||||
)
|
||||
|
||||
result = await get_conversation_history(session_id="test_session")
|
||||
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_cache_connection_error(self):
|
||||
"""Test get_conversation_history handles CacheConnectionError gracefully."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import CacheConnectionError
|
||||
|
||||
mock_cache = create_mock_cache_engine([])
|
||||
mock_cache.get_latest_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed"))
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
get_conversation_history,
|
||||
)
|
||||
|
||||
result = await get_conversation_history(session_id="test_session")
|
||||
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_generic_exception(self):
|
||||
"""Test get_conversation_history handles generic exceptions gracefully."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
mock_cache = create_mock_cache_engine([])
|
||||
mock_cache.get_latest_qa = AsyncMock(side_effect=ValueError("Unexpected error"))
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
get_conversation_history,
|
||||
)
|
||||
|
||||
result = await get_conversation_history(session_id="test_session")
|
||||
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_missing_keys(self):
|
||||
"""Test get_conversation_history handles missing keys in history entries."""
|
||||
user = create_mock_user()
|
||||
session_user.set(user)
|
||||
|
||||
mock_history = [
|
||||
{
|
||||
"time": "2024-01-15 10:30:45",
|
||||
"question": "What is AI?",
|
||||
},
|
||||
{
|
||||
"context": "AI is artificial intelligence",
|
||||
"answer": "AI stands for Artificial Intelligence",
|
||||
},
|
||||
{},
|
||||
]
|
||||
mock_cache = create_mock_cache_engine(mock_history)
|
||||
|
||||
cache_module = importlib.import_module(
|
||||
"cognee.infrastructure.databases.cache.get_cache_engine"
|
||||
)
|
||||
|
||||
with patch.object(cache_module, "get_cache_engine", return_value=mock_cache):
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.session_cache.CacheConfig"
|
||||
) as MockCacheConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.caching = True
|
||||
MockCacheConfig.return_value = mock_config
|
||||
|
||||
from cognee.modules.retrieval.utils.session_cache import (
|
||||
get_conversation_history,
|
||||
)
|
||||
|
||||
result = await get_conversation_history(session_id="test_session")
|
||||
|
||||
assert "Previous conversation:" in result
|
||||
assert "[2024-01-15 10:30:45]" in result
|
||||
assert "QUESTION: What is AI?" in result
|
||||
assert "Unknown time" in result
|
||||
assert "CONTEXT: AI is artificial intelligence" in result
|
||||
assert "ANSWER: AI stands for Artificial Intelligence" in result
|
||||
|
|
@ -0,0 +1,469 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
|
||||
GraphCompletionContextExtensionRetriever,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
||||
|
||||
@pytest.fixture
def mock_edge():
    """Provide a MagicMock constrained to the Edge interface."""
    return MagicMock(spec=Edge)


@pytest.mark.asyncio
async def test_get_triplets_inherited(mock_edge):
    """get_triplets is inherited from the parent retriever and forwards search hits."""
    retriever = GraphCompletionContextExtensionRetriever()

    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
        return_value=[mock_edge],
    ):
        triplets = await retriever.get_triplets("test query")

    # Exactly the single mocked edge comes back.
    assert triplets == [mock_edge]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_defaults():
    """Default construction uses top_k=5 and the stock prompt files."""
    retriever = GraphCompletionContextExtensionRetriever()

    assert retriever.top_k == 5
    assert retriever.user_prompt_path == "graph_context_for_question.txt"
    assert retriever.system_prompt_path == "answer_simple_question.txt"


@pytest.mark.asyncio
async def test_init_custom_params():
    """Every constructor argument is stored on the instance unchanged."""
    retriever = GraphCompletionContextExtensionRetriever(
        top_k=10,
        user_prompt_path="custom_user.txt",
        system_prompt_path="custom_system.txt",
        system_prompt="Custom prompt",
        node_type=str,
        node_name=["node1"],
        save_interaction=True,
        wide_search_top_k=200,
        triplet_distance_penalty=5.0,
    )

    assert retriever.top_k == 10
    assert retriever.user_prompt_path == "custom_user.txt"
    assert retriever.system_prompt_path == "custom_system.txt"
    assert retriever.system_prompt == "Custom prompt"
    assert retriever.node_type is str
    assert retriever.node_name == ["node1"]
    assert retriever.save_interaction is True
    assert retriever.wide_search_top_k == 200
    assert retriever.triplet_distance_penalty == 5.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_edge):
    """get_completion builds its own context from the graph when none is given."""
    graph_engine = AsyncMock()
    graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionContextExtensionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as config_cls,
    ):
        disabled_config = MagicMock()
        disabled_config.caching = False
        config_cls.return_value = disabled_config

        completion = await retriever.get_completion("test query", context_extension_rounds=1)

    assert isinstance(completion, list)
    assert completion == ["Generated answer"]


@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_edge):
    """A caller-supplied context is used directly instead of a graph search."""
    retriever = GraphCompletionContextExtensionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as config_cls,
    ):
        disabled_config = MagicMock()
        disabled_config.caching = False
        config_cls.return_value = disabled_config

        completion = await retriever.get_completion(
            "test query", context=[mock_edge], context_extension_rounds=1
        )

    assert isinstance(completion, list)
    assert completion == ["Generated answer"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_context_extension_rounds(mock_edge):
    """Extension rounds fetch additional triplets before the final answer is generated."""
    graph_engine = AsyncMock()
    graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionContextExtensionRetriever()

    # A distinct edge returned during the extension round.
    extension_edge = MagicMock(spec=Edge)

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=graph_engine,
        ),
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            side_effect=[[mock_edge], [extension_edge]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            side_effect=["Resolved context", "Extended context"],  # per-round contexts
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=["Extension query", "Generated answer"],  # extension query, then answer
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as config_cls,
    ):
        disabled_config = MagicMock()
        disabled_config.caching = False
        config_cls.return_value = disabled_config

        completion = await retriever.get_completion("test query", context_extension_rounds=1)

    assert isinstance(completion, list)
    assert completion == ["Generated answer"]


@pytest.mark.asyncio
async def test_get_completion_context_extension_stops_early(mock_edge):
    """When a round yields no new triplets, extension stops before exhausting rounds."""
    graph_engine = AsyncMock()
    graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionContextExtensionRetriever()

    with (
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=["Extension query", "Generated answer"],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as config_cls,
    ):
        disabled_config = MagicMock()
        disabled_config.caching = False
        config_cls.return_value = disabled_config

        # get_context keeps returning the same triplets, so the loop should stop early.
        completion = await retriever.get_completion(
            "test query", context=[mock_edge], context_extension_rounds=4
        )

    assert isinstance(completion, list)
    assert completion == ["Generated answer"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_edge):
    """Test get_completion with session caching enabled."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionContextExtensionRetriever()

    mock_user = MagicMock()
    mock_user.id = "test-user-id"

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history",
            return_value="Previous conversation",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text",
            return_value="Context summary",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=[
                "Extension query",
                "Generated answer",
            ],  # Extension query, then final answer
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history",
        ) as mock_save,
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
        ) as mock_session_user,
    ):
        # Enable caching and provide a session user so history is saved
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = mock_user

        completion = await retriever.get_completion(
            "test query", session_id="test_session", context_extension_rounds=1
        )

        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] == "Generated answer"
        mock_save.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction(mock_edge):
    """Test get_completion with save_interaction enabled."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    mock_graph_engine.add_edges = AsyncMock()

    retriever = GraphCompletionContextExtensionRetriever(save_interaction=True)

    # Interaction saving resolves both endpoint nodes of each edge
    mock_node1 = MagicMock()
    mock_node2 = MagicMock()
    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=[
                "Extension query",
                "Generated answer",
            ],  # Extension query, then final answer
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
            side_effect=[
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
        ) as mock_add_data,
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion(
            "test query", context=[mock_edge], context_extension_rounds=1
        )

        assert isinstance(completion, list)
        assert len(completion) == 1
        mock_add_data.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_edge):
    """Test get_completion with custom response model."""
    from pydantic import BaseModel

    class TestModel(BaseModel):
        answer: str

    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionContextExtensionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=[
                "Extension query",
                TestModel(answer="Test answer"),
            ],  # Extension query, then final answer
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion(
            "test query", response_model=TestModel, context_extension_rounds=1
        )

        assert isinstance(completion, list)
        assert len(completion) == 1
        assert isinstance(completion[0], TestModel)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_edge):
    """Test get_completion with session config but no user ID."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionContextExtensionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            side_effect=[
                "Extension query",
                "Generated answer",
            ],  # Extension query, then final answer
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user"
        ) as mock_session_user,
    ):
        # Caching enabled but no session user -> completion still succeeds
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = None  # No user

        completion = await retriever.get_completion("test query", context_extension_rounds=1)

        assert isinstance(completion, list)
        assert len(completion) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_zero_extension_rounds(mock_edge):
    """Test get_completion with zero context extension rounds."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionContextExtensionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion("test query", context_extension_rounds=0)

        assert isinstance(completion, list)
        assert len(completion) == 1
|
||||
|
|
@ -0,0 +1,688 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_cot_retriever import (
|
||||
GraphCompletionCotRetriever,
|
||||
_as_answer_text,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
|
||||
|
||||
@pytest.fixture
def mock_edge():
    """Create a mock edge."""
    edge = MagicMock(spec=Edge)
    return edge
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_triplets_inherited(mock_edge):
    """Test that get_triplets is inherited from parent class."""
    retriever = GraphCompletionCotRetriever()

    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
        return_value=[mock_edge],
    ):
        triplets = await retriever.get_triplets("test query")

    assert len(triplets) == 1
    assert triplets[0] == mock_edge
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_custom_params():
    """Test GraphCompletionCotRetriever initialization with custom parameters."""
    retriever = GraphCompletionCotRetriever(
        top_k=10,
        user_prompt_path="custom_user.txt",
        system_prompt_path="custom_system.txt",
        validation_user_prompt_path="custom_validation_user.txt",
        validation_system_prompt_path="custom_validation_system.txt",
        followup_system_prompt_path="custom_followup_system.txt",
        followup_user_prompt_path="custom_followup_user.txt",
    )

    # Every constructor argument must land on its matching attribute
    assert retriever.top_k == 10
    assert retriever.user_prompt_path == "custom_user.txt"
    assert retriever.system_prompt_path == "custom_system.txt"
    assert retriever.validation_user_prompt_path == "custom_validation_user.txt"
    assert retriever.validation_system_prompt_path == "custom_validation_system.txt"
    assert retriever.followup_system_prompt_path == "custom_followup_system.txt"
    assert retriever.followup_user_prompt_path == "custom_followup_user.txt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_defaults():
    """Test GraphCompletionCotRetriever initialization with defaults."""
    retriever = GraphCompletionCotRetriever()

    assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt"
    assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt"
    assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt"
    assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cot_completion_round_zero_with_context(mock_edge):
    """Test _run_cot_completion round 0 with provided context."""
    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=["validation_result", "followup_question"],
        ),
    ):
        completion, context_text, triplets = await retriever._run_cot_completion(
            query="test query",
            context=[mock_edge],
            max_iter=1,
        )

    assert completion == "Generated answer"
    assert context_text == "Resolved context"
    assert len(triplets) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cot_completion_round_zero_without_context(mock_edge):
    """Test _run_cot_completion round 0 without provided context."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
    ):
        completion, context_text, triplets = await retriever._run_cot_completion(
            query="test query",
            context=None,
            max_iter=1,
        )

    assert completion == "Generated answer"
    assert context_text == "Resolved context"
    assert len(triplets) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cot_completion_multiple_rounds(mock_edge):
    """Test _run_cot_completion with multiple rounds."""
    retriever = GraphCompletionCotRetriever()

    mock_edge2 = MagicMock(spec=Edge)

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch.object(
            retriever,
            "get_context",
            new_callable=AsyncMock,
            side_effect=[[mock_edge], [mock_edge2]],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=[
                "validation_result",
                "followup_question",
                "validation_result2",
                "followup_question2",
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
    ):
        completion, context_text, triplets = await retriever._run_cot_completion(
            query="test query",
            context=[mock_edge],
            max_iter=2,
        )

    assert completion == "Generated answer"
    assert context_text == "Resolved context"
    assert len(triplets) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cot_completion_with_conversation_history(mock_edge):
    """Test _run_cot_completion with conversation history."""
    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ) as mock_generate,
    ):
        completion, context_text, triplets = await retriever._run_cot_completion(
            query="test query",
            context=[mock_edge],
            conversation_history="Previous conversation",
            max_iter=1,
        )

    assert completion == "Generated answer"
    # The history string must be forwarded verbatim to generate_completion
    call_kwargs = mock_generate.call_args[1]
    assert call_kwargs.get("conversation_history") == "Previous conversation"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cot_completion_with_response_model(mock_edge):
    """Test _run_cot_completion with custom response model."""
    from pydantic import BaseModel

    class TestModel(BaseModel):
        answer: str

    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value=TestModel(answer="Test answer"),
        ),
    ):
        completion, context_text, triplets = await retriever._run_cot_completion(
            query="test query",
            context=[mock_edge],
            response_model=TestModel,
            max_iter=1,
        )

    # Structured output should be returned as the model instance, not a string
    assert isinstance(completion, TestModel)
    assert completion.answer == "Test answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cot_completion_empty_conversation_history(mock_edge):
    """Test _run_cot_completion with empty conversation history."""
    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ) as mock_generate,
    ):
        completion, context_text, triplets = await retriever._run_cot_completion(
            query="test query",
            context=[mock_edge],
            conversation_history="",
            max_iter=1,
        )

    assert completion == "Generated answer"
    # Verify conversation_history was passed as None when empty
    call_kwargs = mock_generate.call_args[1]
    assert call_kwargs.get("conversation_history") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_edge):
    """Test get_completion retrieves context when not provided."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=["validation_result", "followup_question"],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion("test query", max_iter=1)

        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_edge):
    """Test get_completion uses provided context."""
    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)

        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_edge):
    """Test get_completion with session caching enabled."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionCotRetriever()

    mock_user = MagicMock()
    mock_user.id = "test-user-id"

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history",
            return_value="Previous conversation",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text",
            return_value="Context summary",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history",
        ) as mock_save,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
        ) as mock_session_user,
    ):
        # Enable caching and provide a session user so history is saved
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = mock_user

        completion = await retriever.get_completion(
            "test query", session_id="test_session", max_iter=1
        )

        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] == "Generated answer"
        mock_save.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction(mock_edge):
    """Test get_completion with save_interaction enabled."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)
    mock_graph_engine.add_edges = AsyncMock()

    retriever = GraphCompletionCotRetriever(save_interaction=True)

    # Interaction saving resolves both endpoint nodes of each edge
    mock_node1 = MagicMock()
    mock_node2 = MagicMock()
    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=["validation_result", "followup_question"],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node",
            side_effect=[
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
            ],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.add_data_points",
        ) as mock_add_data,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        # Pass context so save_interaction condition is met
        completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1)

        assert isinstance(completion, list)
        assert len(completion) == 1
        mock_add_data.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_edge):
    """Test get_completion with custom response model."""
    from pydantic import BaseModel

    class TestModel(BaseModel):
        answer: str

    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value=TestModel(answer="Test answer"),
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion(
            "test query", response_model=TestModel, max_iter=1
        )

        assert isinstance(completion, list)
        assert len(completion) == 1
        assert isinstance(completion[0], TestModel)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_edge):
    """Test get_completion with session config but no user ID."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionCotRetriever()

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search",
            return_value=[mock_edge],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.session_user"
        ) as mock_session_user,
    ):
        # Caching enabled but no session user -> completion still succeeds
        mock_config = MagicMock()
        mock_config.caching = True
        mock_cache_config.return_value = mock_config
        mock_session_user.get.return_value = None  # No user

        completion = await retriever.get_completion("test query", max_iter=1)

        assert isinstance(completion, list)
        assert len(completion) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_context(mock_edge):
    """Test get_completion with save_interaction but no context provided."""
    retriever = GraphCompletionCotRetriever(save_interaction=True)

    with (
        patch(
            "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
            return_value="Resolved context",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt",
            return_value="Rendered prompt",
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            side_effect=["validation_result", "followup_question"],
        ),
        patch(
            "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig"
        ) as mock_cache_config,
    ):
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion("test query", context=None, max_iter=1)

        assert isinstance(completion, list)
        assert len(completion) == 1
|
||||
|
||||
|
||||
def test_as_answer_text_with_typeerror():
    """_as_answer_text falls back to str() when json.dumps raises TypeError.

    The original test was declared ``async`` with ``@pytest.mark.asyncio``
    although it never awaits anything; a plain sync test avoids the needless
    event-loop setup.
    """
    non_serializable = {1, 2, 3}  # sets are not JSON-serializable

    result = _as_answer_text(non_serializable)

    assert isinstance(result, str)
    assert result == str(non_serializable)
|
||||
|
||||
|
||||
def test_as_answer_text_with_string():
    """_as_answer_text returns string input unchanged.

    Converted from a needless ``async`` test (no awaits) to a plain sync test.
    """
    result = _as_answer_text("test string")
    assert result == "test string"
|
||||
|
||||
|
||||
def test_as_answer_text_with_dict():
    """_as_answer_text serializes a dict to a string containing its items.

    Converted from a needless ``async`` test (no awaits) to a plain sync test.
    """
    test_dict = {"key": "value", "number": 42}
    result = _as_answer_text(test_dict)

    assert isinstance(result, str)
    assert "key" in result
    assert "value" in result
|
||||
|
||||
|
||||
def test_as_answer_text_with_basemodel():
    """Pydantic models are rendered with a '[Structured Response]' prefix.

    Converted from a needless ``async`` test (no awaits) to a plain sync test,
    and the local model renamed from ``TestModel`` so pytest does not try to
    collect it as a test class.
    """
    from pydantic import BaseModel

    class AnswerModel(BaseModel):
        answer: str

    result = _as_answer_text(AnswerModel(answer="test answer"))

    assert isinstance(result, str)
    assert "[Structured Response]" in result
    assert "test answer" in result
|
||||
|
|
@ -0,0 +1,648 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
||||
|
||||
@pytest.fixture
def mock_edge():
    """Provide an Edge-spec'd MagicMock for triplet-related tests."""
    return MagicMock(spec=Edge)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_triplets_success(mock_edge):
    """A non-empty search result is passed through unchanged by get_triplets."""
    retriever = GraphCompletionRetriever(top_k=5)

    target = "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search"
    with patch(target, return_value=[mock_edge]) as search_mock:
        triplets = await retriever.get_triplets("test query")

        assert triplets == [mock_edge]
        search_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_triplets_empty_results():
    """An empty search result surfaces as an empty triplet list."""
    retriever = GraphCompletionRetriever()

    target = "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search"
    with patch(target, return_value=[]):
        assert await retriever.get_triplets("test query") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_triplets_top_k_parameter():
    """The retriever's top_k is forwarded to brute_force_triplet_search."""
    retriever = GraphCompletionRetriever(top_k=10)

    target = "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search"
    with patch(target, return_value=[]) as search_mock:
        await retriever.get_triplets("test query")

        # call_args.kwargs is the named-attribute form of call_args[1]
        assert search_mock.call_args.kwargs["top_k"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_success(mock_edge):
    """With a populated graph, get_context returns the triplets it found."""
    retriever = GraphCompletionRetriever()

    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=False)

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.brute_force_triplet_search", return_value=[mock_edge]),
    ):
        context = await retriever.get_context("test query")

        assert isinstance(context, list)
        assert context == [mock_edge]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_empty_results():
    """A populated graph with no matching triplets yields an empty context."""
    retriever = GraphCompletionRetriever()

    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=False)

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.brute_force_triplet_search", return_value=[]),
    ):
        assert await retriever.get_context("test query") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_empty_graph():
    """An empty graph short-circuits get_context to an empty list."""
    retriever = GraphCompletionRetriever()

    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=True)

    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine",
        return_value=engine,
    ):
        assert await retriever.get_context("test query") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_edges_to_text(mock_edge):
    """The retriever delegates edge rendering to the module-level helper."""
    retriever = GraphCompletionRetriever()

    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
        return_value="Resolved text",
    ) as resolve_mock:
        text = await retriever.resolve_edges_to_text([mock_edge])

        assert text == "Resolved text"
        resolve_mock.assert_awaited_once_with([mock_edge])
|
||||
|
||||
|
||||
def test_init_defaults():
    """A bare GraphCompletionRetriever gets the documented default settings.

    The original test was marked ``@pytest.mark.asyncio`` and declared
    ``async`` although it performs no awaits; plain sync avoids the loop.
    """
    retriever = GraphCompletionRetriever()

    assert retriever.top_k == 5
    assert retriever.user_prompt_path == "graph_context_for_question.txt"
    assert retriever.system_prompt_path == "answer_simple_question.txt"
    assert retriever.node_type is None
    assert retriever.node_name is None
|
||||
|
||||
|
||||
def test_init_custom_params():
    """Every constructor argument lands on the matching attribute.

    Converted from a needless ``async`` test (no awaits) to a plain sync test.
    """
    retriever = GraphCompletionRetriever(
        top_k=10,
        user_prompt_path="custom_user.txt",
        system_prompt_path="custom_system.txt",
        system_prompt="Custom prompt",
        node_type=str,
        node_name=["node1"],
        save_interaction=True,
        wide_search_top_k=200,
        triplet_distance_penalty=5.0,
    )

    assert retriever.top_k == 10
    assert retriever.user_prompt_path == "custom_user.txt"
    assert retriever.system_prompt_path == "custom_system.txt"
    assert retriever.system_prompt == "Custom prompt"
    assert retriever.node_type is str
    assert retriever.node_name == ["node1"]
    assert retriever.save_interaction is True
    assert retriever.wide_search_top_k == 200
    assert retriever.triplet_distance_penalty == 5.0
|
||||
|
||||
|
||||
def test_init_none_top_k():
    """Passing top_k=None falls back to the default of 5.

    Converted from a needless ``async`` test (no awaits) to a plain sync test.
    """
    retriever = GraphCompletionRetriever(top_k=None)

    assert retriever.top_k == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_convert_retrieved_objects_to_context(mock_edge):
    """convert_retrieved_objects_to_context delegates to resolve_edges_to_text."""
    retriever = GraphCompletionRetriever()

    with patch(
        "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text",
        return_value="Resolved text",
    ) as resolve_mock:
        rendered = await retriever.convert_retrieved_objects_to_context([mock_edge])

        assert rendered == "Resolved text"
        resolve_mock.assert_awaited_once_with([mock_edge])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_edge):
    """When no context is passed, get_completion retrieves one and answers."""
    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionRetriever()

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.brute_force_triplet_search", return_value=[mock_edge]),
        patch(f"{base}.resolve_edges_to_text", return_value="Resolved context"),
        patch(f"{base}.generate_completion", return_value="Generated answer"),
        patch(f"{base}.CacheConfig") as cache_config_cls,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        cache_config_cls.return_value = cache_cfg

        completion = await retriever.get_completion("test query")

        assert isinstance(completion, list)
        assert completion == ["Generated answer"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_edge):
    """A caller-supplied context bypasses retrieval and is used directly."""
    retriever = GraphCompletionRetriever()

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.resolve_edges_to_text", return_value="Resolved context"),
        patch(f"{base}.generate_completion", return_value="Generated answer"),
        patch(f"{base}.CacheConfig") as cache_config_cls,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        cache_config_cls.return_value = cache_cfg

        completion = await retriever.get_completion("test query", context=[mock_edge])

        assert isinstance(completion, list)
        assert completion == ["Generated answer"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_edge):
    """With caching on and a session user, history is saved after answering."""
    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionRetriever()

    user = MagicMock()
    user.id = "test-user-id"

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.brute_force_triplet_search", return_value=[mock_edge]),
        patch(f"{base}.resolve_edges_to_text", return_value="Resolved context"),
        patch(f"{base}.get_conversation_history", return_value="Previous conversation"),
        patch(f"{base}.summarize_text", return_value="Context summary"),
        patch(f"{base}.generate_completion", return_value="Generated answer"),
        patch(f"{base}.save_conversation_history") as save_history_mock,
        patch(f"{base}.CacheConfig") as cache_config_cls,
        patch(f"{base}.session_user") as session_user_var,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = True
        cache_config_cls.return_value = cache_cfg
        session_user_var.get.return_value = user

        completion = await retriever.get_completion("test query", session_id="test_session")

        assert isinstance(completion, list)
        assert completion == ["Generated answer"]
        save_history_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_edge):
    """A structured response model instance is returned to the caller unmodified."""
    from pydantic import BaseModel

    class AnswerModel(BaseModel):
        answer: str

    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionRetriever()

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.brute_force_triplet_search", return_value=[mock_edge]),
        patch(f"{base}.resolve_edges_to_text", return_value="Resolved context"),
        patch(f"{base}.generate_completion", return_value=AnswerModel(answer="Test answer")),
        patch(f"{base}.CacheConfig") as cache_config_cls,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        cache_config_cls.return_value = cache_cfg

        completion = await retriever.get_completion("test query", response_model=AnswerModel)

        assert isinstance(completion, list)
        assert len(completion) == 1
        assert isinstance(completion[0], AnswerModel)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_empty_context(mock_edge):
    """An empty triplet search still yields exactly one completion."""
    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionRetriever()

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.brute_force_triplet_search", return_value=[]),
        patch(f"{base}.resolve_edges_to_text", return_value=""),
        patch(f"{base}.generate_completion", return_value="Generated answer"),
        patch(f"{base}.CacheConfig") as cache_config_cls,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        cache_config_cls.return_value = cache_cfg

        completion = await retriever.get_completion("test query")

        assert isinstance(completion, list)
        assert len(completion) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_save_qa(mock_edge):
    """save_qa persists the QA datapoint and links it to the triplet nodes."""
    engine = AsyncMock()
    engine.add_edges = AsyncMock()

    retriever = GraphCompletionRetriever()

    mock_edge.node1 = MagicMock()
    mock_edge.node2 = MagicMock()

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.extract_uuid_from_node", side_effect=["uuid1", "uuid2"]),
        patch(f"{base}.add_data_points") as add_data_mock,
    ):
        await retriever.save_qa(
            question="Test question",
            answer="Test answer",
            context="Test context",
            triplets=[mock_edge],
        )

        add_data_mock.assert_awaited_once()
        engine.add_edges.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_save_qa_no_triplet_ids(mock_edge):
    """Without extractable node UUIDs, the datapoint is saved but no edges are added."""
    engine = AsyncMock()
    engine.add_edges = AsyncMock()

    retriever = GraphCompletionRetriever()

    mock_edge.node1 = MagicMock()
    mock_edge.node2 = MagicMock()

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.extract_uuid_from_node", return_value=None),
        patch(f"{base}.add_data_points") as add_data_mock,
    ):
        await retriever.save_qa(
            question="Test question",
            answer="Test answer",
            context="Test context",
            triplets=[mock_edge],
        )

        add_data_mock.assert_awaited_once()
        engine.add_edges.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_save_qa_empty_triplets():
    """With no triplets at all, save_qa stores the datapoint and skips edge creation."""
    engine = AsyncMock()
    engine.add_edges = AsyncMock()

    retriever = GraphCompletionRetriever()

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.add_data_points") as add_data_mock,
    ):
        await retriever.save_qa(
            question="Test question",
            answer="Test answer",
            context="Test context",
            triplets=[],
        )

        add_data_mock.assert_awaited_once()
        engine.add_edges.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_completion(mock_edge):
    """save_interaction=True with a None completion still returns a one-item list."""
    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionRetriever(save_interaction=True)

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.brute_force_triplet_search", return_value=[mock_edge]),
        patch(f"{base}.resolve_edges_to_text", return_value="Resolved context"),
        patch(f"{base}.generate_completion", return_value=None),  # LLM produced nothing
        patch(f"{base}.CacheConfig") as cache_config_cls,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        cache_config_cls.return_value = cache_cfg

        completion = await retriever.get_completion("test query")

        assert isinstance(completion, list)
        assert len(completion) == 1
        assert completion[0] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_no_context(mock_edge):
    """save_interaction=True and context=None: context is retrieved and answered."""
    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionRetriever(save_interaction=True)

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.brute_force_triplet_search", return_value=[mock_edge]),
        patch(f"{base}.resolve_edges_to_text", return_value="Resolved context"),
        patch(f"{base}.generate_completion", return_value="Generated answer"),
        patch(f"{base}.CacheConfig") as cache_config_cls,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        cache_config_cls.return_value = cache_cfg

        completion = await retriever.get_completion("test query", context=None)

        assert isinstance(completion, list)
        assert len(completion) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge):
    """With save_interaction on and valid triplet UUIDs, the QA pair is persisted."""
    engine = AsyncMock()
    engine.is_empty = AsyncMock(return_value=False)

    retriever = GraphCompletionRetriever(save_interaction=True)

    mock_edge.node1 = MagicMock()
    mock_edge.node2 = MagicMock()

    base = "cognee.modules.retrieval.graph_completion_retriever"
    with (
        patch(f"{base}.get_graph_engine", return_value=engine),
        patch(f"{base}.brute_force_triplet_search", return_value=[mock_edge]),
        patch(f"{base}.resolve_edges_to_text", return_value="Resolved context"),
        patch(f"{base}.generate_completion", return_value="Generated answer"),
        patch(
            f"{base}.extract_uuid_from_node",
            side_effect=[
                UUID("550e8400-e29b-41d4-a716-446655440000"),
                UUID("550e8400-e29b-41d4-a716-446655440001"),
            ],
        ),
        patch(f"{base}.add_data_points") as add_data_mock,
        patch(f"{base}.CacheConfig") as cache_config_cls,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        cache_config_cls.return_value = cache_cfg

        completion = await retriever.get_completion("test query", context=[mock_edge])

        assert isinstance(completion, list)
        assert completion == ["Generated answer"]
        add_data_mock.assert_awaited_once()
|
||||
|
|
@ -0,0 +1,321 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vector_engine():
    """Provide an async vector-engine stub with a mockable search()."""
    stub = AsyncMock()
    stub.search = AsyncMock()
    return stub
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
    """Chunk texts from the vector search are joined with newlines."""
    first, second = MagicMock(), MagicMock()
    first.payload = {"text": "Steve Rodger"}
    second.payload = {"text": "Mike Broski"}
    mock_vector_engine.search.return_value = [first, second]

    retriever = CompletionRetriever(top_k=2)

    with patch(
        "cognee.modules.retrieval.completion_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

        assert context == "Steve Rodger\nMike Broski"
        mock_vector_engine.search.assert_awaited_once_with(
            "DocumentChunk_text", "test query", limit=2
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
    """A missing vector collection surfaces to callers as NoDataError."""
    mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")

    retriever = CompletionRetriever()

    with patch(
        "cognee.modules.retrieval.completion_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ), pytest.raises(NoDataError, match="No data found"):
        await retriever.get_context("test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
    """No matching chunks produce an empty context string."""
    mock_vector_engine.search.return_value = []

    retriever = CompletionRetriever()

    with patch(
        "cognee.modules.retrieval.completion_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        assert await retriever.get_context("test query") == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_top_k_limit(mock_vector_engine):
    """top_k is forwarded as the search limit and all hits are joined."""
    hits = []
    for idx in range(2):
        hit = MagicMock()
        hit.payload = {"text": f"Chunk {idx}"}
        hits.append(hit)
    mock_vector_engine.search.return_value = hits

    retriever = CompletionRetriever(top_k=2)

    with patch(
        "cognee.modules.retrieval.completion_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

        assert context == "Chunk 0\nChunk 1"
        mock_vector_engine.search.assert_awaited_once_with(
            "DocumentChunk_text", "test query", limit=2
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_single_chunk(mock_vector_engine):
    """A single hit comes back as its bare text with no separator."""
    hit = MagicMock()
    hit.payload = {"text": "Single chunk text"}
    mock_vector_engine.search.return_value = [hit]

    retriever = CompletionRetriever()

    with patch(
        "cognee.modules.retrieval.completion_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        assert await retriever.get_context("test query") == "Single chunk text"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_without_session(mock_vector_engine):
    """With caching disabled, get_completion returns a single generated answer."""
    hit = MagicMock()
    hit.payload = {"text": "Chunk text"}
    mock_vector_engine.search.return_value = [hit]

    retriever = CompletionRetriever()

    mod = "cognee.modules.retrieval.completion_retriever"
    with (
        patch(f"{mod}.get_vector_engine", return_value=mock_vector_engine),
        patch(f"{mod}.generate_completion", return_value="Generated answer"),
        patch(f"{mod}.CacheConfig") as cache_config_cls,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        cache_config_cls.return_value = cache_cfg

        completion = await retriever.get_completion("test query")

        assert isinstance(completion, list)
        assert completion == ["Generated answer"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_vector_engine):
    """A caller-supplied context string skips the vector search entirely."""
    retriever = CompletionRetriever()

    mod = "cognee.modules.retrieval.completion_retriever"
    with (
        patch(f"{mod}.generate_completion", return_value="Generated answer"),
        patch(f"{mod}.CacheConfig") as cache_config_cls,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        cache_config_cls.return_value = cache_cfg

        completion = await retriever.get_completion("test query", context="Provided context")

        assert isinstance(completion, list)
        assert completion == ["Generated answer"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_vector_engine):
    """With caching enabled and a session user, conversation history is saved."""
    hit = MagicMock()
    hit.payload = {"text": "Chunk text"}
    mock_vector_engine.search.return_value = [hit]

    retriever = CompletionRetriever()

    user = MagicMock()
    user.id = "test-user-id"

    mod = "cognee.modules.retrieval.completion_retriever"
    with (
        patch(f"{mod}.get_vector_engine", return_value=mock_vector_engine),
        patch(f"{mod}.get_conversation_history", return_value="Previous conversation"),
        patch(f"{mod}.summarize_text", return_value="Context summary"),
        patch(f"{mod}.generate_completion", return_value="Generated answer"),
        patch(f"{mod}.save_conversation_history") as save_history_mock,
        patch(f"{mod}.CacheConfig") as cache_config_cls,
        patch(f"{mod}.session_user") as session_user_var,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = True
        cache_config_cls.return_value = cache_cfg
        session_user_var.get.return_value = user

        completion = await retriever.get_completion("test query", session_id="test_session")

        assert isinstance(completion, list)
        assert completion == ["Generated answer"]
        save_history_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_vector_engine):
    """Test get_completion with session config but no user ID."""
    hit = MagicMock(payload={"text": "Chunk text"})
    mock_vector_engine.search.return_value = [hit]

    retriever = CompletionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.completion_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.completion_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as cfg_cls,
        patch("cognee.modules.retrieval.completion_retriever.session_user") as session_user_var,
    ):
        cfg_cls.return_value = MagicMock(caching=True)
        session_user_var.get.return_value = None  # no user resolved for the session

        answer = await retriever.get_completion("test query")

    assert isinstance(answer, list)
    assert len(answer) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_vector_engine):
    """Test get_completion with custom response model."""
    from pydantic import BaseModel

    class TestModel(BaseModel):
        answer: str

    hit = MagicMock(payload={"text": "Chunk text"})
    mock_vector_engine.search.return_value = [hit]

    retriever = CompletionRetriever()

    with (
        patch(
            "cognee.modules.retrieval.completion_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.completion_retriever.generate_completion",
            return_value=TestModel(answer="Test answer"),
        ),
        patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as cfg_cls,
    ):
        cfg_cls.return_value = MagicMock(caching=False)

        answer = await retriever.get_completion("test query", response_model=TestModel)

    # The structured result should come back wrapped in a single-item list.
    assert isinstance(answer, list)
    assert len(answer) == 1
    assert isinstance(answer[0], TestModel)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_defaults():
    """Test CompletionRetriever initialization with defaults."""
    retriever = CompletionRetriever()

    expected_defaults = {
        "user_prompt_path": "context_for_question.txt",
        "system_prompt_path": "answer_simple_question.txt",
        "top_k": 1,
        "system_prompt": None,
    }
    for attr, value in expected_defaults.items():
        assert getattr(retriever, attr) == value


@pytest.mark.asyncio
async def test_init_custom_params():
    """Test CompletionRetriever initialization with custom parameters."""
    overrides = {
        "user_prompt_path": "custom_user.txt",
        "system_prompt_path": "custom_system.txt",
        "system_prompt": "Custom prompt",
        "top_k": 10,
    }
    retriever = CompletionRetriever(**overrides)

    # Every override must land on the instance unchanged.
    for attr, value in overrides.items():
        assert getattr(retriever, attr) == value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_missing_text_key(mock_vector_engine):
    """Test get_context handles missing text key in payload."""
    hit = MagicMock(payload={"other_key": "value"})
    mock_vector_engine.search.return_value = [hit]

    retriever = CompletionRetriever()

    # A payload without "text" should surface as a KeyError.
    with (
        patch(
            "cognee.modules.retrieval.completion_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        pytest.raises(KeyError),
    ):
        await retriever.get_context("test query")
|
||||
193
cognee/tests/unit/modules/retrieval/summaries_retriever_test.py
Normal file
193
cognee/tests/unit/modules/retrieval/summaries_retriever_test.py
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vector_engine():
    """Provide an AsyncMock vector engine with an awaitable ``search``."""
    mocked = AsyncMock()
    mocked.search = AsyncMock()
    return mocked
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
    """Test successful retrieval of summary context."""
    first = MagicMock(payload={"text": "S.R.", "made_from": "chunk1"})
    second = MagicMock(payload={"text": "M.B.", "made_from": "chunk2"})
    mock_vector_engine.search.return_value = [first, second]

    retriever = SummariesRetriever(top_k=5)

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    # Both payloads come back, in search order.
    assert [entry["text"] for entry in context] == ["S.R.", "M.B."]
    mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
    """Test that CollectionNotFoundError is converted to NoDataError."""
    mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found")

    retriever = SummariesRetriever()

    with (
        patch(
            "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        pytest.raises(NoDataError, match="No data found"),
    ):
        await retriever.get_context("test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
    """Test that empty list is returned when no summaries are found."""
    mock_vector_engine.search.return_value = []

    retriever = SummariesRetriever()

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        assert await retriever.get_context("test query") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_top_k_limit(mock_vector_engine):
    """Test that top_k parameter limits the number of results."""
    hits = [MagicMock(payload={"text": f"Summary {i}"}) for i in range(3)]
    mock_vector_engine.search.return_value = hits

    retriever = SummariesRetriever(top_k=3)

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    assert len(context) == 3
    # top_k must be forwarded to the vector search as the limit.
    mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_context(mock_vector_engine):
    """Test get_completion returns provided context."""
    retriever = SummariesRetriever()
    given = [{"text": "S.R."}, {"text": "M.B."}]

    # The provided context should be echoed back untouched.
    assert await retriever.get_completion("test query", context=given) == given
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_vector_engine):
    """Test get_completion retrieves context when not provided."""
    mock_vector_engine.search.return_value = [MagicMock(payload={"text": "S.R."})]

    retriever = SummariesRetriever()

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        completion = await retriever.get_completion("test query")

    assert len(completion) == 1
    assert completion[0]["text"] == "S.R."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_defaults():
    """Test SummariesRetriever initialization with defaults."""
    assert SummariesRetriever().top_k == 5


@pytest.mark.asyncio
async def test_init_custom_top_k():
    """Test SummariesRetriever initialization with custom top_k."""
    assert SummariesRetriever(top_k=10).top_k == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_empty_payload(mock_vector_engine):
    """Test get_context handles empty payload."""
    mock_vector_engine.search.return_value = [MagicMock(payload={})]

    retriever = SummariesRetriever()

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        context = await retriever.get_context("test query")

    # An empty payload is passed through rather than raising.
    assert context == [{}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session_id(mock_vector_engine):
    """Test get_completion with session_id parameter."""
    mock_vector_engine.search.return_value = [MagicMock(payload={"text": "S.R."})]

    retriever = SummariesRetriever()

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        completion = await retriever.get_completion("test query", session_id="test_session")

    assert len(completion) == 1
    assert completion[0]["text"] == "S.R."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_kwargs(mock_vector_engine):
    """Test get_completion accepts additional kwargs."""
    mock_vector_engine.search.return_value = [MagicMock(payload={"text": "S.R."})]

    retriever = SummariesRetriever()

    with patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        # Unknown keyword arguments must be tolerated, not rejected.
        completion = await retriever.get_completion("test query", extra_param="value")

    assert len(completion) == 1
|
||||
705
cognee/tests/unit/modules/retrieval/temporal_retriever_test.py
Normal file
705
cognee/tests/unit/modules/retrieval/temporal_retriever_test.py
Normal file
|
|
@ -0,0 +1,705 @@
|
|||
from types import SimpleNamespace
|
||||
import pytest
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from datetime import datetime
|
||||
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
|
||||
|
||||
# Test TemporalRetriever initialization defaults and overrides
def test_init_defaults_and_overrides():
    expected_defaults = {
        "top_k": 5,
        "user_prompt_path": "graph_context_for_question.txt",
        "system_prompt_path": "answer_simple_question.txt",
        "time_extraction_prompt_path": "extract_query_time.txt",
    }
    retriever = TemporalRetriever()
    for attr, value in expected_defaults.items():
        assert getattr(retriever, attr) == value

    overrides = {
        "top_k": 3,
        "user_prompt_path": "u.txt",
        "system_prompt_path": "s.txt",
        "time_extraction_prompt_path": "t.txt",
    }
    retriever = TemporalRetriever(**overrides)
    for attr, value in overrides.items():
        assert getattr(retriever, attr) == value
|
||||
|
||||
|
||||
# Test descriptions_to_string with basic and empty results
def test_descriptions_to_string_basic_and_empty():
    retriever = TemporalRetriever()

    rows = [
        {"description": " First "},
        {"nope": "no description"},
        {"description": "Second"},
        {"description": ""},
        {"description": " Third line "},
    ]

    # Entries without a non-empty description are dropped; the rest are
    # stripped and joined with the separator line.
    rendered = retriever.descriptions_to_string(rows)
    assert rendered == "First\n#####################\nSecond\n#####################\nThird line"

    assert retriever.descriptions_to_string([]) == ""
|
||||
|
||||
|
||||
# Test filter_top_k_events sorts and limits correctly
@pytest.mark.asyncio
async def test_filter_top_k_events_sorts_and_limits():
    retriever = TemporalRetriever(top_k=2)

    events = [
        {
            "events": [
                {"id": "e1", "description": "E1"},
                {"id": "e2", "description": "E2"},
                {"id": "e3", "description": "E3 - not in vector results"},
            ]
        }
    ]

    scores = [
        SimpleNamespace(payload={"id": "e2"}, score=0.10),
        SimpleNamespace(payload={"id": "e1"}, score=0.20),
    ]

    top = await retriever.filter_top_k_events(events, scores)

    # Best (lowest) score first, limited to top_k, with scores attached.
    assert [event["id"] for event in top] == ["e2", "e1"]
    assert [event["score"] for event in top] == [0.10, 0.20]
|
||||
|
||||
|
||||
# Test filter_top_k_events handles unknown ids as infinite scores
@pytest.mark.asyncio
async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k():
    retriever = TemporalRetriever(top_k=2)

    events = [
        {
            "events": [
                {"id": "known1", "description": "Known 1"},
                {"id": "unknown", "description": "Unknown"},
                {"id": "known2", "description": "Known 2"},
            ]
        }
    ]

    scores = [
        SimpleNamespace(payload={"id": "known2"}, score=0.05),
        SimpleNamespace(payload={"id": "known1"}, score=0.50),
    ]

    top = await retriever.filter_top_k_events(events, scores)

    # The unscored event sorts last (infinite score) and falls outside top_k.
    assert [event["id"] for event in top] == ["known2", "known1"]
    assert float("inf") not in {event["score"] for event in top}
|
||||
|
||||
|
||||
# Test descriptions_to_string with unicode and newlines
def test_descriptions_to_string_unicode_and_newlines():
    retriever = TemporalRetriever()

    rendered = retriever.descriptions_to_string(
        [
            {"description": "Line A\nwith newline"},
            {"description": "This is a description"},
        ]
    )

    assert "Line A\nwith newline" in rendered
    assert "This is a description" in rendered
    # Exactly one separator between the two descriptions.
    assert rendered.count("#####################") == 1
|
||||
|
||||
|
||||
# Test filter_top_k_events when top_k is larger than available events
@pytest.mark.asyncio
async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
    retriever = TemporalRetriever(top_k=10)

    events = [{"events": [{"id": "a"}, {"id": "b"}]}]
    scores = [
        SimpleNamespace(payload={"id": "a"}, score=0.1),
        SimpleNamespace(payload={"id": "b"}, score=0.2),
    ]

    filtered = await retriever.filter_top_k_events(events, scores)
    assert [event["id"] for event in filtered] == ["a", "b"]
|
||||
|
||||
|
||||
# Test filter_top_k_events when scored_results is empty
@pytest.mark.asyncio
async def test_filter_top_k_events_handles_empty_scored_results():
    retriever = TemporalRetriever(top_k=2)

    filtered = await retriever.filter_top_k_events(
        [{"events": [{"id": "x"}, {"id": "y"}]}], []
    )

    # With no scores, every event gets an infinite score but is still kept.
    assert [event["id"] for event in filtered] == ["x", "y"]
    assert {event["score"] for event in filtered} == {float("inf")}
|
||||
|
||||
|
||||
# Test filter_top_k_events error handling for missing structure
@pytest.mark.asyncio
async def test_filter_top_k_events_error_handling():
    retriever = TemporalRetriever(top_k=2)

    # A dict without the "events" key is malformed input and must raise.
    with pytest.raises((KeyError, TypeError)):
        await retriever.filter_top_k_events([{}], [])
|
||||
|
||||
|
||||
@pytest.fixture
def mock_graph_engine():
    """Provide an AsyncMock graph engine with awaitable collect_* methods."""
    mocked = AsyncMock()
    mocked.collect_time_ids = AsyncMock()
    mocked.collect_events = AsyncMock()
    return mocked


@pytest.fixture
def mock_vector_engine():
    """Provide an AsyncMock vector engine with embedding and search mocks."""
    mocked = AsyncMock()
    mocked.embedding_engine = AsyncMock()
    mocked.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mocked.search = AsyncMock()
    return mocked
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine):
    """Test get_context when time range is extracted from query."""
    retriever = TemporalRetriever(top_k=5)

    mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
                {"id": "e2", "description": "Event 2"},
            ]
        }
    ]
    mock_vector_engine.search.return_value = [
        SimpleNamespace(payload={"id": "e2"}, score=0.05),
        SimpleNamespace(payload={"id": "e1"}, score=0.10),
    ]

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
    ):
        context = await retriever.get_context("What happened in 2024?")

    # A non-empty string rendering of the event descriptions is expected.
    assert isinstance(context, str)
    assert len(context) > 0
    assert "Event" in context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine):
    """Test get_context falls back to triplets when no time is extracted."""
    retriever = TemporalRetriever()

    # Stub the time extraction to report "no time found".
    async def no_time(query):
        return None, None

    retriever.extract_time_from_query = no_time

    with (
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(
            retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
        ) as triplets_mock,
        patch.object(
            retriever, "resolve_edges_to_text", return_value="triplet text"
        ) as resolve_mock,
    ):
        context = await retriever.get_context("test query")

    assert context == "triplet text"
    triplets_mock.assert_awaited_once_with("test query")
    resolve_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_no_events_found(mock_graph_engine):
    """Test get_context falls back to triplets when no events are found."""
    retriever = TemporalRetriever()

    # A valid interval is extracted, but the graph yields no matching ids.
    mock_graph_engine.collect_time_ids.return_value = []

    async def fixed_interval(query):
        return "2024-01-01", "2024-12-31"

    retriever.extract_time_from_query = fixed_interval

    with (
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch.object(
            retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}]
        ) as triplets_mock,
        patch.object(
            retriever, "resolve_edges_to_text", return_value="triplet text"
        ) as resolve_mock,
    ):
        context = await retriever.get_context("test query")

    assert context == "triplet text"
    triplets_mock.assert_awaited_once_with("test query")
    resolve_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine):
    """Test get_context with only time_from."""
    retriever = TemporalRetriever(top_k=5)

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {"events": [{"id": "e1", "description": "Event 1"}]}
    ]
    mock_vector_engine.search.return_value = [
        SimpleNamespace(payload={"id": "e1"}, score=0.05)
    ]

    with (
        patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
    ):
        context = await retriever.get_context("What happened after 2024?")

    assert isinstance(context, str)
    assert "Event 1" in context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
    """Test get_context with only time_to."""
    retriever = TemporalRetriever(top_k=5)

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {"events": [{"id": "e1", "description": "Event 1"}]}
    ]
    mock_vector_engine.search.return_value = [
        SimpleNamespace(payload={"id": "e1"}, score=0.05)
    ]

    with (
        patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
    ):
        context = await retriever.get_context("What happened before 2024?")

    assert isinstance(context, str)
    assert "Event 1" in context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine):
    """Test get_completion retrieves context when not provided."""
    retriever = TemporalRetriever()

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {"events": [{"id": "e1", "description": "Event 1"}]}
    ]
    mock_vector_engine.search.return_value = [
        SimpleNamespace(payload={"id": "e1"}, score=0.05)
    ]

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as cfg_cls,
    ):
        cfg_cls.return_value = MagicMock(caching=False)

        answer = await retriever.get_completion("What happened in 2024?")

    assert isinstance(answer, list)
    assert len(answer) == 1
    assert answer[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_provided_context():
    """Test get_completion uses provided context."""
    retriever = TemporalRetriever()

    with (
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as cfg_cls,
    ):
        cfg_cls.return_value = MagicMock(caching=False)

        answer = await retriever.get_completion("test query", context="Provided context")

    assert isinstance(answer, list)
    assert len(answer) == 1
    assert answer[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine):
    """Test get_completion with session caching enabled."""
    retriever = TemporalRetriever()

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {"events": [{"id": "e1", "description": "Event 1"}]}
    ]
    mock_vector_engine.search.return_value = [
        SimpleNamespace(payload={"id": "e1"}, score=0.05)
    ]

    user = MagicMock(id="test-user-id")

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_conversation_history",
            return_value="Previous conversation",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.summarize_text",
            return_value="Context summary",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.save_conversation_history",
        ) as save_history,
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as cfg_cls,
        patch("cognee.modules.retrieval.temporal_retriever.session_user") as session_user_var,
    ):
        # Caching on plus a resolved user: the session path must persist history.
        cfg_cls.return_value = MagicMock(caching=True)
        session_user_var.get.return_value = user

        answer = await retriever.get_completion(
            "What happened in 2024?", session_id="test_session"
        )

    assert isinstance(answer, list)
    assert len(answer) == 1
    assert answer[0] == "Generated answer"
    save_history.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine):
    """Test get_completion with session config but no user ID."""
    retriever = TemporalRetriever()

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {"events": [{"id": "e1", "description": "Event 1"}]}
    ]
    mock_vector_engine.search.return_value = [
        SimpleNamespace(payload={"id": "e1"}, score=0.05)
    ]

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value="Generated answer",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as cfg_cls,
        patch("cognee.modules.retrieval.temporal_retriever.session_user") as session_user_var,
    ):
        cfg_cls.return_value = MagicMock(caching=True)
        session_user_var.get.return_value = None  # no user resolved for the session

        answer = await retriever.get_completion("What happened in 2024?")

    assert isinstance(answer, list)
    assert len(answer) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_context_retrieved_but_empty(mock_graph_engine):
    """Test get_completion when get_context returns empty string."""
    retriever = TemporalRetriever()

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
        ) as get_vector_mock,
        patch.object(retriever, "filter_top_k_events", return_value=[]),
    ):
        vector_engine = AsyncMock()
        vector_engine.embedding_engine = AsyncMock()
        vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
        vector_engine.search = AsyncMock(return_value=[])
        get_vector_mock.return_value = vector_engine

        mock_graph_engine.collect_time_ids.return_value = ["e1"]
        mock_graph_engine.collect_events.return_value = [
            {"events": [{"id": "e1", "description": ""}]}
        ]

        # NOTE(review): this pins a known defect — with an empty filtered
        # context the production code hits an unbound local, so the call is
        # expected to raise rather than return.
        with pytest.raises((UnboundLocalError, NameError)):
            await retriever.get_completion("test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine):
    """Test get_completion with custom response model.

    When a pydantic response_model is passed through, the returned completion
    list should carry an instance of that model rather than a plain string.
    """
    from pydantic import BaseModel

    # Minimal structured-output schema used only within this test.
    class TestModel(BaseModel):
        answer: str

    retriever = TemporalRetriever()

    mock_graph_engine.collect_time_ids.return_value = ["e1"]
    mock_graph_engine.collect_events.return_value = [
        {
            "events": [
                {"id": "e1", "description": "Event 1"},
            ]
        }
    ]

    mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
    mock_vector_engine.search.return_value = [mock_result]

    with (
        patch.object(
            retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31")
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.generate_completion",
            return_value=TestModel(answer="Test answer"),
        ),
        patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config,
    ):
        # Caching disabled so no session/user plumbing is exercised here.
        mock_config = MagicMock()
        mock_config.caching = False
        mock_cache_config.return_value = mock_config

        completion = await retriever.get_completion(
            "What happened in 2024?", response_model=TestModel
        )

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert isinstance(completion[0], TestModel)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_time_from_query_relative_path():
    """Test extract_time_from_query with relative prompt path.

    os.path.isabs is patched to False so the retriever resolves the prompt
    relative to its default prompt directory.
    """
    retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")

    mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
    mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
    mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)

    with (
        patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
        patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
        patch(
            "cognee.modules.retrieval.temporal_retriever.render_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            return_value=mock_interval,
        ),
    ):
        # Freeze "today" as rendered into the prompt.
        mock_datetime.now.return_value.strftime.return_value = "11-12-2024"

        time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")

    assert time_from == mock_timestamp_from
    assert time_to == mock_timestamp_to
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_time_from_query_absolute_path():
    """Test extract_time_from_query with absolute prompt path.

    When the prompt path is absolute, the retriever is expected to split it
    into dirname (base directory) and basename (prompt file) — both patched
    here so no real filesystem access is required.
    """
    retriever = TemporalRetriever(
        time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt"
    )

    mock_timestamp_from = Timestamp(year=2024, month=1, day=1)
    mock_timestamp_to = Timestamp(year=2024, month=12, day=31)
    mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to)

    with (
        patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True),
        patch(
            "cognee.modules.retrieval.temporal_retriever.os.path.dirname",
            return_value="/absolute/path/to",
        ),
        patch(
            "cognee.modules.retrieval.temporal_retriever.os.path.basename",
            return_value="extract_query_time.txt",
        ),
        patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
        patch(
            "cognee.modules.retrieval.temporal_retriever.render_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            return_value=mock_interval,
        ),
    ):
        mock_datetime.now.return_value.strftime.return_value = "11-12-2024"

        time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?")

    assert time_from == mock_timestamp_from
    assert time_to == mock_timestamp_to
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_time_from_query_with_none_values():
    """Test extract_time_from_query when interval has None values.

    A query without temporal cues yields an open interval; both bounds
    must pass through as None rather than raising.
    """
    retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt")

    mock_interval = QueryInterval(starts_at=None, ends_at=None)

    with (
        patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False),
        patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime,
        patch(
            "cognee.modules.retrieval.temporal_retriever.render_prompt",
            return_value="System prompt",
        ),
        patch.object(
            LLMGateway,
            "acreate_structured_output",
            new_callable=AsyncMock,
            return_value=mock_interval,
        ),
    ):
        mock_datetime.now.return_value.strftime.return_value = "11-12-2024"

        time_from, time_to = await retriever.extract_time_from_query("What happened?")

    assert time_from is None
    assert time_to is None
|
||||
|
|
@ -0,0 +1,817 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||
brute_force_triplet_search,
|
||||
get_memory_fragment,
|
||||
format_triplets,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||
|
||||
|
||||
class MockScoredResult:
    """Lightweight stand-in for a scored vector-search result.

    Mirrors the attributes the triplet-search pipeline reads from real
    results: ``id``, ``score``, and an optional ``payload`` mapping.
    """

    def __init__(self, id, score, payload=None):
        self.id = id
        self.score = score
        # Falsy payload (None or empty) normalizes to a fresh empty dict.
        self.payload = {} if not payload else payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_empty_query():
    """An empty query string must be rejected with ValueError."""
    expected_message = "The query must be a non-empty string."
    with pytest.raises(ValueError, match=expected_message):
        await brute_force_triplet_search(query="")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
    """Passing None as the query must be rejected with ValueError."""
    expected_message = "The query must be a non-empty string."
    with pytest.raises(ValueError, match=expected_message):
        await brute_force_triplet_search(query=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_negative_top_k():
    """A negative top_k must be rejected with ValueError."""
    expected_message = "top_k must be a positive integer."
    with pytest.raises(ValueError, match=expected_message):
        await brute_force_triplet_search(query="test query", top_k=-1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_zero_top_k():
    """top_k of zero is not a positive integer and must raise ValueError."""
    expected_message = "top_k must be a positive integer."
    with pytest.raises(ValueError, match=expected_message):
        await brute_force_triplet_search(query="test query", top_k=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_global_search():
    """Test that wide_search_limit is applied for global search (node_name=None)."""
    # Stub engine: fixed embedding vector, searches return no hits.
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,  # Global search
            wide_search_top_k=75,
        )

    # Every per-collection search call must carry the caller-provided limit.
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] == 75
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
    """Test that wide_search_limit is None for filtered search (node_name provided)."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=["Node1"],
            wide_search_top_k=50,
        )

    # With node-name filtering, the wide limit is ignored: searches are unbounded.
    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_default():
    """Without an explicit wide_search_top_k, each collection search uses limit=100."""
    engine_stub = AsyncMock()
    engine_stub.embedding_engine = AsyncMock()
    engine_stub.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine_stub.search = AsyncMock(return_value=[])

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
    with patch(target, return_value=engine_stub):
        await brute_force_triplet_search(query="test", node_name=None)

    assert all(
        call.kwargs["limit"] == 100 for call in engine_stub.search.call_args_list
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_default_collections():
    """Test that default collections are used when none provided."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test")

    # Default collection set, in the order the implementation queries them.
    expected_collections = [
        "Entity_name",
        "TextSummary_text",
        "EntityType_name",
        "DocumentChunk_text",
        "EdgeType_relationship_name",
    ]

    call_collections = [
        call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
    ]
    # List (not set) comparison deliberately pins the search order too.
    assert call_collections == expected_collections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_custom_collections():
    """Test that custom collections are used when provided."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    custom_collections = ["CustomCol1", "CustomCol2"]

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", collections=custom_collections)

    call_collections = [
        call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
    ]
    # The edge-type collection is always searched in addition to the custom ones.
    assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_always_includes_edge_collection():
    """Test that EdgeType_relationship_name is always searched even when not in collections."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    # Deliberately omit the edge collection from the caller-provided list.
    collections_without_edge = ["Entity_name", "TextSummary_text"]

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", collections=collections_without_edge)

    call_collections = [
        call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
    ]
    assert "EdgeType_relationship_name" in call_collections
    # And nothing beyond the requested collections plus the implicit edge one.
    assert set(call_collections) == set(collections_without_edge) | {
        "EdgeType_relationship_name"
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_all_collections_empty():
    """When every collection yields no hits, the search returns an empty list."""
    engine_stub = AsyncMock()
    engine_stub.embedding_engine = AsyncMock()
    engine_stub.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine_stub.search = AsyncMock(return_value=[])

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
    with patch(target, return_value=engine_stub):
        results = await brute_force_triplet_search(query="test")

    assert results == []
|
||||
|
||||
|
||||
# Tests for query embedding
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_embeds_query():
    """The query text is embedded exactly once and that vector drives every search."""
    query_text = "test query"
    expected_vector = [0.1, 0.2, 0.3]

    engine_stub = AsyncMock()
    engine_stub.embedding_engine = AsyncMock()
    engine_stub.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
    engine_stub.search = AsyncMock(return_value=[])

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
    with patch(target, return_value=engine_stub):
        await brute_force_triplet_search(query=query_text)

    # Embedding happens once, on the raw query text.
    engine_stub.embedding_engine.embed_text.assert_called_once_with([query_text])

    # Each per-collection search reuses that same embedded vector.
    assert all(
        call.kwargs["query_vector"] == expected_vector
        for call in engine_stub.search.call_args_list
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
    """Test that node IDs are extracted from search results for global search."""
    scored_results = [
        MockScoredResult("node1", 0.95),
        MockScoredResult("node2", 0.87),
        MockScoredResult("node3", 0.92),
    ]

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=scored_results)

    # Fragment stub exposing only the async methods the pipeline awaits.
    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    # The hit ids from all collections must be forwarded to the fragment factory.
    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_reuses_provided_fragment():
    """Test that provided memory fragment is reused instead of creating new one."""
    provided_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
        ) as mock_get_fragment,
    ):
        await brute_force_triplet_search(
            query="test",
            memory_fragment=provided_fragment,
            node_name=["node"],
        )

    # The factory must not run when the caller supplies a fragment.
    mock_get_fragment.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
    """Test that memory fragment is created when not provided."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        # No memory_fragment kwarg: the factory must be invoked exactly once.
        await brute_force_triplet_search(query="test", node_name=["node"])

    mock_get_fragment.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
    """Test that custom top_k is passed to importance calculation."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ),
    ):
        custom_top_k = 15
        await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])

    # The final ranking step must receive the caller's top_k as keyword `k`.
    mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
    """Test that get_memory_fragment returns empty graph when entity not found (line 85)."""
    mock_graph_engine = AsyncMock()

    # Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called
    mock_fragment = MagicMock(spec=CogneeGraph)
    mock_fragment.project_graph_from_db = AsyncMock(
        side_effect=EntityNotFoundError("Entity not found")
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        # CogneeGraph() inside get_memory_fragment yields our failing fragment.
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph",
            return_value=mock_fragment,
        ),
    ):
        result = await get_memory_fragment()

    # Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85)
    assert result == mock_fragment
    mock_fragment.project_graph_from_db.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_error():
    """A generic projection failure yields an empty CogneeGraph, not a crash."""
    failing_engine = AsyncMock()
    # NOTE(review): the side effect sits on the *engine's* attribute here,
    # whereas the sibling test patches the fragment's own project_graph_from_db.
    # Confirm the error actually propagates through the real CogneeGraph path.
    failing_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))

    target = "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine"
    with patch(target, return_value=failing_engine):
        fragment = await get_memory_fragment()

    assert isinstance(fragment, CogneeGraph)
    assert len(fragment.nodes) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_deduplicates_node_ids():
    """Test that duplicate node IDs across collections are deduplicated."""

    def search_side_effect(*args, **kwargs):
        # "node1" is returned by two collections on purpose.
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            ]
        elif collection_name == "TextSummary_text":
            return [
                MockScoredResult("node1", 0.90),
                MockScoredResult("node3", 0.92),
            ]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
    # Length check proves the duplicate "node1" was collapsed to one entry.
    assert len(call_kwargs["relevant_ids_to_filter"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_excludes_edge_collection():
    """Test that EdgeType_relationship_name collection is excluded from ID extraction."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "EdgeType_relationship_name":
            # Edge hits exist but must never appear among node-filter ids.
            return [MockScoredResult("edge1", 0.88)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,
            collections=["Entity_name", "EdgeType_relationship_name"],
        )

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert call_kwargs["relevant_ids_to_filter"] == ["node1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_skips_nodes_without_ids():
    """Test that nodes without ID attribute are skipped."""

    class ScoredResultNoId:
        """Mock result without id attribute."""

        def __init__(self, score):
            self.score = score

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            # The middle result has no .id and must be silently dropped.
            return [
                MockScoredResult("node1", 0.95),
                ScoredResultNoId(0.90),
                MockScoredResult("node2", 0.87),
            ]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_handles_tuple_results():
    """Test that both list and tuple results are handled correctly."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            # Return a tuple (not a list) to exercise generic-iterable handling.
            return (
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            )
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_mixed_empty_collections():
    """Test ID extraction with mixed empty and non-empty collections."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "TextSummary_text":
            return []
        elif collection_name == "EntityType_name":
            return [MockScoredResult("node2", 0.92)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    # Empty collections contribute nothing; non-empty ones still merge cleanly.
    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
|
||||
|
||||
|
||||
def test_format_triplets():
    """Test format_triplets function."""
    source, target = MagicMock(), MagicMock()
    source.attributes = {"name": "Node1", "type": "Entity", "id": "n1"}
    target.attributes = {"name": "Node2", "type": "Entity", "id": "n2"}

    edge = MagicMock()
    edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"}
    edge.node1 = source
    edge.node2 = target

    rendered = format_triplets([edge])

    # All node names and edge attributes must survive into the rendered string.
    assert isinstance(rendered, str)
    for token in ("Node1", "Node2", "relates_to", "connects"):
        assert token in rendered
|
||||
|
||||
|
||||
def test_format_triplets_with_none_values():
    """Test format_triplets filters out None values."""
    mock_edge = MagicMock()
    mock_node1 = MagicMock()
    mock_node2 = MagicMock()

    # None-valued attributes on nodes and edges must be dropped from the output.
    mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"}
    mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None}
    mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None}

    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2

    result = format_triplets([mock_edge])

    assert "Node1" in result
    assert "Node2" in result
    assert "relates_to" in result
    # Fixed: the previous assertion
    #   assert "None" not in result or result.count("None") == 0
    # was a tautology — both operands test the identical condition, so the
    # second clause was dead code. Assert the intended condition directly.
    assert "None" not in result
|
||||
|
||||
|
||||
def test_format_triplets_with_nested_dict():
    """Test format_triplets handles nested dict attributes (lines 23-35)."""
    first, second = MagicMock(), MagicMock()
    # Node metadata is a nested dict — exercises the recursive formatting path.
    first.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}}
    second.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}}

    edge = MagicMock()
    edge.attributes = {"relationship_name": "relates_to"}
    edge.node1 = first
    edge.node2 = second

    output = format_triplets([edge])

    assert isinstance(output, str)
    for token in ("Node1", "Node2", "relates_to"):
        assert token in output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_vector_engine_init_error():
    """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147)."""
    # A failure while obtaining the vector engine surfaces as RuntimeError.
    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        side_effect=Exception("Initialization error"),
    ):
        with pytest.raises(RuntimeError, match="Initialization error"):
            await brute_force_triplet_search(query="test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_error():
    """Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157)."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    # The first collection lookup raises; the remaining lookups return no hits.
    engine.search = AsyncMock(
        side_effect=[CollectionNotFoundError("Collection not found"), [], []]
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=CogneeGraph(),
        ),
    ):
        outcome = await brute_force_triplet_search(
            query="test query", collections=["missing_collection", "existing_collection"]
        )

    # A missing collection is tolerated and the search degrades to no results.
    assert outcome == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_generic_exception():
    """Test brute_force_triplet_search handles generic exceptions (lines 209-217)."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(side_effect=Exception("Generic error"))

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=engine,
    ):
        # Unlike CollectionNotFoundError, arbitrary errors propagate to the caller.
        with pytest.raises(Exception, match="Generic error"):
            await brute_force_triplet_search(query="test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none():
    """Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191)."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(
        return_value=[MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})]
    )

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ) as fragment_factory,
    ):
        await brute_force_triplet_search(query="test query", node_name=["Node1"])

    # node_name-scoped searches must not pass a vector-derived ID filter.
    assert fragment_factory.called
    kwargs = fragment_factory.call_args.kwargs if fragment_factory.call_args else {}
    assert kwargs.get("relevant_ids_to_filter") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_at_top_level():
    """Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210)."""
    engine = AsyncMock()
    engine.embedding_engine = AsyncMock()
    engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    engine.search = AsyncMock(
        return_value=[MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})]
    )

    fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        # Raising here exercises the top-level CollectionNotFoundError handler.
        calculate_top_triplet_importances=AsyncMock(
            side_effect=CollectionNotFoundError("Collection not found")
        ),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=fragment,
        ),
    ):
        outcome = await brute_force_triplet_search(query="test query")

    assert outcome == []
|
||||
343
cognee/tests/unit/modules/retrieval/test_completion.py
Normal file
343
cognee/tests/unit/modules/retrieval/test_completion.py
Normal file
|
|
@ -0,0 +1,343 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from typing import Type
|
||||
|
||||
|
||||
class TestGenerateCompletion:
    """Unit tests for generate_completion.

    Covers: explicit system_prompt vs. system_prompt read from file,
    conversation-history prepending, custom response models, and the
    arguments forwarded to render_prompt. All LLM traffic is stubbed
    through LLMGateway.acreate_structured_output.
    """

    @pytest.mark.asyncio
    async def test_generate_completion_with_system_prompt(self):
        """Test generate_completion with provided system_prompt."""
        mock_llm_response = "Generated answer"

        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            # Imported inside the patch context so the patched names are in effect.
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                system_prompt="Custom system prompt",
            )

        assert result == mock_llm_response
        # An explicit system_prompt wins over system_prompt_path.
        mock_llm.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt="Custom system prompt",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_without_system_prompt(self):
        """Test generate_completion reads system_prompt from file when not provided."""
        mock_llm_response = "Generated answer"

        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
            )

        assert result == mock_llm_response
        mock_llm.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt="System prompt from file",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_with_conversation_history(self):
        """Test generate_completion includes conversation_history in system_prompt."""
        mock_llm_response = "Generated answer"

        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
            )

        assert result == mock_llm_response
        # History is prepended verbatim, joined with "\nTASK:" before the base prompt.
        expected_system_prompt = (
            "Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
            + "\nTASK:"
            + "System prompt from file"
        )
        mock_llm.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt=expected_system_prompt,
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self):
        """Test generate_completion includes conversation_history with custom system_prompt."""
        mock_llm_response = "Generated answer"

        # NOTE: read_query_prompt is deliberately NOT patched here — with an
        # explicit system_prompt the file should never be read.
        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                system_prompt="Custom system prompt",
                conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning",
            )

        assert result == mock_llm_response
        expected_system_prompt = (
            "Previous conversation:\nQ: What is ML?\nA: ML is machine learning"
            + "\nTASK:"
            + "Custom system prompt"
        )
        mock_llm.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt=expected_system_prompt,
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_with_response_model(self):
        """Test generate_completion with custom response_model."""
        mock_response_model = MagicMock()
        mock_llm_response = {"answer": "Generated answer"}

        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            result = await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
                response_model=mock_response_model,
            )

        assert result == mock_llm_response
        # The caller-supplied response_model must be forwarded instead of str.
        mock_llm.assert_awaited_once_with(
            text_input="User prompt text",
            system_prompt="System prompt from file",
            response_model=mock_response_model,
        )

    @pytest.mark.asyncio
    async def test_generate_completion_render_prompt_args(self):
        """Test generate_completion passes correct args to render_prompt."""
        mock_llm_response = "Generated answer"

        with (
            patch(
                "cognee.modules.retrieval.utils.completion.render_prompt",
                return_value="User prompt text",
            ) as mock_render,
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ),
        ):
            from cognee.modules.retrieval.utils.completion import generate_completion

            await generate_completion(
                query="What is AI?",
                context="AI is artificial intelligence",
                user_prompt_path="user_prompt.txt",
                system_prompt_path="system_prompt.txt",
            )

        # render_prompt receives the template path and a query/context mapping.
        mock_render.assert_called_once_with(
            "user_prompt.txt",
            {"question": "What is AI?", "context": "AI is artificial intelligence"},
        )
|
||||
|
||||
|
||||
class TestSummarizeText:
    """Unit tests for summarize_text.

    Covers: explicit system_prompt vs. prompt loaded via read_query_prompt,
    and the default vs. custom system_prompt_path. The LLM call is stubbed
    through LLMGateway.acreate_structured_output.
    """

    @pytest.mark.asyncio
    async def test_summarize_text_with_system_prompt(self):
        """Test summarize_text with provided system_prompt."""
        mock_llm_response = "Summary text"

        with patch(
            "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
            new_callable=AsyncMock,
            return_value=mock_llm_response,
        ) as mock_llm:
            # Imported inside the patch context so the patched name is in effect.
            from cognee.modules.retrieval.utils.completion import summarize_text

            result = await summarize_text(
                text="Long text to summarize",
                system_prompt_path="summarize_search_results.txt",
                system_prompt="Custom summary prompt",
            )

        assert result == mock_llm_response
        # An explicit system_prompt wins over system_prompt_path.
        mock_llm.assert_awaited_once_with(
            text_input="Long text to summarize",
            system_prompt="Custom summary prompt",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_summarize_text_without_system_prompt(self):
        """Test summarize_text reads system_prompt from file when not provided."""
        mock_llm_response = "Summary text"

        with (
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="System prompt from file",
            ),
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import summarize_text

            result = await summarize_text(
                text="Long text to summarize",
                system_prompt_path="summarize_search_results.txt",
            )

        assert result == mock_llm_response
        mock_llm.assert_awaited_once_with(
            text_input="Long text to summarize",
            system_prompt="System prompt from file",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_summarize_text_default_prompt_path(self):
        """Test summarize_text uses default prompt path when not provided."""
        mock_llm_response = "Summary text"

        with (
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="Default system prompt",
            ) as mock_read,
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import summarize_text

            result = await summarize_text(text="Long text to summarize")

        assert result == mock_llm_response
        # Default path documented by the function's signature.
        mock_read.assert_called_once_with("summarize_search_results.txt")
        mock_llm.assert_awaited_once_with(
            text_input="Long text to summarize",
            system_prompt="Default system prompt",
            response_model=str,
        )

    @pytest.mark.asyncio
    async def test_summarize_text_custom_prompt_path(self):
        """Test summarize_text uses custom prompt path when provided."""
        mock_llm_response = "Summary text"

        with (
            patch(
                "cognee.modules.retrieval.utils.completion.read_query_prompt",
                return_value="Custom system prompt",
            ) as mock_read,
            patch(
                "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output",
                new_callable=AsyncMock,
                return_value=mock_llm_response,
            ) as mock_llm,
        ):
            from cognee.modules.retrieval.utils.completion import summarize_text

            result = await summarize_text(
                text="Long text to summarize",
                system_prompt_path="custom_prompt.txt",
            )

        assert result == mock_llm_response
        mock_read.assert_called_once_with("custom_prompt.txt")
        mock_llm.assert_awaited_once_with(
            text_input="Long text to summarize",
            system_prompt="Custom system prompt",
            response_model=str,
        )
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||
GraphSummaryCompletionRetriever,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
||||
|
||||
@pytest.fixture
def mock_edge():
    """Create a mock edge."""
    # spec=Edge restricts the mock to the real Edge interface.
    return MagicMock(spec=Edge)
|
||||
|
||||
|
||||
class TestGraphSummaryCompletionRetriever:
    """Unit tests for GraphSummaryCompletionRetriever.

    Covers constructor defaults/overrides and resolve_edges_to_text, which is
    expected to delegate to the parent GraphCompletionRetriever and then pass
    the resolved text through summarize_text.
    """

    @pytest.mark.asyncio
    async def test_init_defaults(self):
        """Test GraphSummaryCompletionRetriever initialization with defaults."""
        retriever = GraphSummaryCompletionRetriever()

        assert retriever.summarize_prompt_path == "summarize_search_results.txt"
        assert retriever.user_prompt_path == "graph_context_for_question.txt"
        assert retriever.system_prompt_path == "answer_simple_question.txt"
        assert retriever.top_k == 5
        assert retriever.save_interaction is False

    @pytest.mark.asyncio
    async def test_init_custom_params(self):
        """Test GraphSummaryCompletionRetriever initialization with custom parameters."""
        retriever = GraphSummaryCompletionRetriever(
            user_prompt_path="custom_user.txt",
            system_prompt_path="custom_system.txt",
            summarize_prompt_path="custom_summarize.txt",
            system_prompt="Custom system prompt",
            top_k=10,
            save_interaction=True,
            wide_search_top_k=200,
            triplet_distance_penalty=2.5,
        )

        assert retriever.summarize_prompt_path == "custom_summarize.txt"
        assert retriever.user_prompt_path == "custom_user.txt"
        assert retriever.system_prompt_path == "custom_system.txt"
        assert retriever.top_k == 10
        assert retriever.save_interaction is True

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge):
        """Test resolve_edges_to_text calls super method and then summarizes."""
        retriever = GraphSummaryCompletionRetriever(
            summarize_prompt_path="custom_summarize.txt",
            system_prompt="Custom system prompt",
        )

        with (
            patch(
                "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
                new_callable=AsyncMock,
                return_value="Resolved edges text",
            ) as mock_super_resolve,
            patch(
                "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
                new_callable=AsyncMock,
                return_value="Summarized text",
            ) as mock_summarize,
        ):
            result = await retriever.resolve_edges_to_text([mock_edge])

        assert result == "Summarized text"
        mock_super_resolve.assert_awaited_once_with([mock_edge])
        # summarize_text receives (resolved text, prompt path, system prompt).
        mock_summarize.assert_awaited_once_with(
            "Resolved edges text",
            "custom_summarize.txt",
            "Custom system prompt",
        )

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge):
        """Test resolve_edges_to_text uses None for system_prompt when not provided."""
        retriever = GraphSummaryCompletionRetriever()

        with (
            patch(
                "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
                new_callable=AsyncMock,
                return_value="Resolved edges text",
            ),
            patch(
                "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
                new_callable=AsyncMock,
                return_value="Summarized text",
            ) as mock_summarize,
        ):
            await retriever.resolve_edges_to_text([mock_edge])

        mock_summarize.assert_awaited_once_with(
            "Resolved edges text",
            "summarize_search_results.txt",
            None,
        )

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_with_empty_edges(self):
        """Test resolve_edges_to_text handles empty edges list."""
        retriever = GraphSummaryCompletionRetriever()

        with (
            patch(
                "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
                new_callable=AsyncMock,
                return_value="",
            ),
            patch(
                "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
                new_callable=AsyncMock,
                return_value="Empty summary",
            ) as mock_summarize,
        ):
            result = await retriever.resolve_edges_to_text([])

        assert result == "Empty summary"
        # Even an empty resolution is still summarized.
        mock_summarize.assert_awaited_once_with(
            "",
            "summarize_search_results.txt",
            None,
        )

    @pytest.mark.asyncio
    async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge):
        """Test resolve_edges_to_text handles multiple edges."""
        retriever = GraphSummaryCompletionRetriever()

        mock_edge2 = MagicMock(spec=Edge)
        mock_edge3 = MagicMock(spec=Edge)

        with (
            patch(
                "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text",
                new_callable=AsyncMock,
                return_value="Multiple edges resolved text",
            ),
            patch(
                "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text",
                new_callable=AsyncMock,
                return_value="Multiple edges summarized",
            ) as mock_summarize,
        ):
            result = await retriever.resolve_edges_to_text([mock_edge, mock_edge2, mock_edge3])

        assert result == "Multiple edges summarized"
        mock_summarize.assert_awaited_once_with(
            "Multiple edges resolved text",
            "summarize_search_results.txt",
            None,
        )
|
||||
312
cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py
Normal file
312
cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import UUID, NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
||||
from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
|
||||
|
||||
@pytest.fixture
def mock_feedback_evaluation():
    """Create a mock feedback evaluation."""
    sentiment = MagicMock()
    sentiment.value = "positive"

    feedback = MagicMock(spec=UserFeedbackEvaluation)
    feedback.evaluation = sentiment
    feedback.score = 4.5
    return feedback
|
||||
|
||||
|
||||
@pytest.fixture
def mock_graph_engine():
    """Create a mock graph engine."""
    # No prior interactions by default; tests override as needed.
    return AsyncMock(
        get_last_user_interaction_ids=AsyncMock(return_value=[]),
        add_edges=AsyncMock(),
        apply_feedback_weight=AsyncMock(),
    )
|
||||
|
||||
|
||||
class TestUserQAFeedback:
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_default(self):
|
||||
"""Test UserQAFeedback initialization with default last_k."""
|
||||
retriever = UserQAFeedback()
|
||||
assert retriever.last_k == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_custom_last_k(self):
|
||||
"""Test UserQAFeedback initialization with custom last_k."""
|
||||
retriever = UserQAFeedback(last_k=5)
|
||||
assert retriever.last_k == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_feedback_success_with_relationships(
|
||||
self, mock_feedback_evaluation, mock_graph_engine
|
||||
):
|
||||
"""Test add_feedback successfully creates feedback with relationships."""
|
||||
interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
|
||||
interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001"))
|
||||
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(
|
||||
return_value=[interaction_id_1, interaction_id_2]
|
||||
)
|
||||
|
||||
feedback_text = "This answer was helpful"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_feedback_evaluation,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_add_data,
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_index_edges,
|
||||
):
|
||||
retriever = UserQAFeedback(last_k=2)
|
||||
result = await retriever.add_feedback(feedback_text)
|
||||
|
||||
assert result == [feedback_text]
|
||||
mock_add_data.assert_awaited_once()
|
||||
mock_graph_engine.add_edges.assert_awaited_once()
|
||||
mock_index_edges.assert_awaited_once()
|
||||
mock_graph_engine.apply_feedback_weight.assert_awaited_once()
|
||||
|
||||
# Verify add_edges was called with correct relationships
|
||||
call_args = mock_graph_engine.add_edges.call_args[0][0]
|
||||
assert len(call_args) == 2
|
||||
assert call_args[0][0] == uuid5(NAMESPACE_OID, name=feedback_text)
|
||||
assert call_args[0][1] == UUID(interaction_id_1)
|
||||
assert call_args[0][2] == "gives_feedback_to"
|
||||
assert call_args[0][3]["relationship_name"] == "gives_feedback_to"
|
||||
assert call_args[0][3]["ontology_valid"] is False
|
||||
|
||||
# Verify apply_feedback_weight was called with correct node IDs
|
||||
weight_call_args = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"]
|
||||
assert len(weight_call_args) == 2
|
||||
assert interaction_id_1 in weight_call_args
|
||||
assert interaction_id_2 in weight_call_args
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_feedback_success_no_relationships(
|
||||
self, mock_feedback_evaluation, mock_graph_engine
|
||||
):
|
||||
"""Test add_feedback successfully creates feedback without relationships."""
|
||||
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
|
||||
|
||||
feedback_text = "This answer was helpful"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_feedback_evaluation,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_add_data,
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_index_edges,
|
||||
):
|
||||
retriever = UserQAFeedback(last_k=1)
|
||||
result = await retriever.add_feedback(feedback_text)
|
||||
|
||||
assert result == [feedback_text]
|
||||
mock_add_data.assert_awaited_once()
|
||||
# Should not call add_edges or index_graph_edges when no relationships
|
||||
mock_graph_engine.add_edges.assert_not_awaited()
|
||||
mock_index_edges.assert_not_awaited()
|
||||
mock_graph_engine.apply_feedback_weight.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_feedback_creates_correct_feedback_node(
|
||||
self, mock_feedback_evaluation, mock_graph_engine
|
||||
):
|
||||
"""Test add_feedback creates CogneeUserFeedback with correct attributes."""
|
||||
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
|
||||
|
||||
feedback_text = "This was a negative experience"
|
||||
mock_feedback_evaluation.evaluation.value = "negative"
|
||||
mock_feedback_evaluation.score = -3.0
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_feedback_evaluation,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_add_data,
|
||||
):
|
||||
retriever = UserQAFeedback()
|
||||
await retriever.add_feedback(feedback_text)
|
||||
|
||||
# Verify add_data_points was called with correct CogneeUserFeedback
|
||||
call_args = mock_add_data.call_args[1]["data_points"]
|
||||
assert len(call_args) == 1
|
||||
feedback_node = call_args[0]
|
||||
assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text)
|
||||
assert feedback_node.feedback == feedback_text
|
||||
assert feedback_node.sentiment == "negative"
|
||||
assert feedback_node.score == -3.0
|
||||
assert isinstance(feedback_node.belongs_to_set, NodeSet)
|
||||
assert feedback_node.belongs_to_set.name == "UserQAFeedbacks"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_feedback_calls_llm_with_correct_prompt(
|
||||
self, mock_feedback_evaluation, mock_graph_engine
|
||||
):
|
||||
"""Test add_feedback calls LLM with correct sentiment analysis prompt."""
|
||||
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
|
||||
|
||||
feedback_text = "Great answer!"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_feedback_evaluation,
|
||||
) as mock_llm,
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
retriever = UserQAFeedback()
|
||||
await retriever.add_feedback(feedback_text)
|
||||
|
||||
mock_llm.assert_awaited_once()
|
||||
call_kwargs = mock_llm.call_args[1]
|
||||
assert call_kwargs["text_input"] == feedback_text
|
||||
assert "sentiment analysis assistant" in call_kwargs["system_prompt"]
|
||||
assert call_kwargs["response_model"] == UserFeedbackEvaluation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_feedback_uses_last_k_parameter(
|
||||
self, mock_feedback_evaluation, mock_graph_engine
|
||||
):
|
||||
"""Test add_feedback uses last_k parameter when getting interaction IDs."""
|
||||
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[])
|
||||
|
||||
feedback_text = "Test feedback"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_feedback_evaluation,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
retriever = UserQAFeedback(last_k=5)
|
||||
await retriever.add_feedback(feedback_text)
|
||||
|
||||
mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_feedback_with_single_interaction(
|
||||
self, mock_feedback_evaluation, mock_graph_engine
|
||||
):
|
||||
"""Test add_feedback with single interaction ID."""
|
||||
interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
|
||||
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
|
||||
|
||||
feedback_text = "Test feedback"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_feedback_evaluation,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
retriever = UserQAFeedback()
|
||||
result = await retriever.add_feedback(feedback_text)
|
||||
|
||||
assert result == [feedback_text]
|
||||
# Should create relationship for the interaction
|
||||
call_args = mock_graph_engine.add_edges.call_args[0][0]
|
||||
assert len(call_args) == 1
|
||||
assert call_args[0][1] == UUID(interaction_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_feedback_applies_weight_correctly(
|
||||
self, mock_feedback_evaluation, mock_graph_engine
|
||||
):
|
||||
"""Test add_feedback applies feedback weight with correct score."""
|
||||
interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000"))
|
||||
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id])
|
||||
mock_feedback_evaluation.score = 4.5
|
||||
|
||||
feedback_text = "Positive feedback"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_feedback_evaluation,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.add_data_points",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
retriever = UserQAFeedback()
|
||||
await retriever.add_feedback(feedback_text)
|
||||
|
||||
mock_graph_engine.apply_feedback_weight.assert_awaited_once_with(
|
||||
node_ids=[interaction_id], weight=4.5
|
||||
)
|
||||
329
cognee/tests/unit/modules/retrieval/triplet_retriever_test.py
Normal file
329
cognee/tests/unit/modules/retrieval/triplet_retriever_test.py
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
|
||||
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vector_engine():
    """Async vector-engine stub: collection exists by default, search is configurable."""
    stub = AsyncMock()
    stub.search = AsyncMock()
    stub.has_collection = AsyncMock(return_value=True)
    return stub
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
    """Context is the newline-joined text of every matching triplet."""
    hit_a, hit_b = MagicMock(), MagicMock()
    hit_a.payload = {"text": "Alice knows Bob"}
    hit_b.payload = {"text": "Bob works at Tech Corp"}
    mock_vector_engine.search.return_value = [hit_a, hit_b]

    retriever = TripletRetriever(top_k=5)

    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch:
        context = await retriever.get_context("test query")

    assert context == "Alice knows Bob\nBob works at Tech Corp"
    # top_k must be forwarded as the search limit on the triplet collection.
    mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_no_collection(mock_vector_engine):
    """A missing Triplet_text collection surfaces as NoDataError with a remediation hint."""
    mock_vector_engine.has_collection.return_value = False

    retriever = TripletRetriever()

    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch, pytest.raises(NoDataError, match="create_triplet_embeddings"):
        await retriever.get_context("test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
    """No search hits produce an empty context string."""
    mock_vector_engine.search.return_value = []

    retriever = TripletRetriever()

    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch:
        context = await retriever.get_context("test query")

    assert context == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
    """CollectionNotFoundError from the engine is re-raised as NoDataError."""
    mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found")

    retriever = TripletRetriever()

    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch, pytest.raises(NoDataError, match="No data found"):
        await retriever.get_context("test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_empty_payload_text(mock_vector_engine):
    """A payload lacking the 'text' key currently propagates a KeyError."""
    hit = MagicMock()
    hit.payload = {}
    mock_vector_engine.search.return_value = [hit]

    retriever = TripletRetriever()

    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch, pytest.raises(KeyError):
        await retriever.get_context("test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_context_single_triplet(mock_vector_engine):
    """A single hit comes back verbatim with no separator added."""
    hit = MagicMock()
    hit.payload = {"text": "Single triplet"}
    mock_vector_engine.search.return_value = [hit]

    retriever = TripletRetriever()

    engine_patch = patch(
        "cognee.modules.retrieval.triplet_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch:
        context = await retriever.get_context("test query")

    assert context == "Single triplet"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_defaults():
    """Default construction selects the stock prompts and a top_k of 5."""
    retriever = TripletRetriever()

    observed = (
        retriever.user_prompt_path,
        retriever.system_prompt_path,
        retriever.top_k,
    )
    assert observed == ("context_for_question.txt", "answer_simple_question.txt", 5)
    assert retriever.system_prompt is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_custom_params():
    """Constructor stores every explicitly supplied parameter unchanged."""
    retriever = TripletRetriever(
        user_prompt_path="custom_user.txt",
        system_prompt_path="custom_system.txt",
        system_prompt="Custom prompt",
        top_k=10,
    )

    observed = (
        retriever.user_prompt_path,
        retriever.system_prompt_path,
        retriever.system_prompt,
        retriever.top_k,
    )
    assert observed == ("custom_user.txt", "custom_system.txt", "Custom prompt", 10)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_without_context(mock_vector_engine):
    """When no context is passed, get_completion retrieves one, then completes."""
    hit = MagicMock()
    hit.payload = {"text": "Test triplet"}
    mock_vector_engine.has_collection.return_value = True
    mock_vector_engine.search.return_value = [hit]

    retriever = TripletRetriever()

    mod = "cognee.modules.retrieval.triplet_retriever"
    with (
        patch(f"{mod}.get_vector_engine", return_value=mock_vector_engine),
        patch(f"{mod}.generate_completion", return_value="Generated answer"),
        patch(f"{mod}.CacheConfig") as mock_cache_config,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False  # disable session caching for this path
        mock_cache_config.return_value = cache_cfg

        completion = await retriever.get_completion("test query")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_provided_context(mock_vector_engine):
    """A caller-supplied context is used directly (no retrieval patching needed)."""
    retriever = TripletRetriever()

    mod = "cognee.modules.retrieval.triplet_retriever"
    with (
        patch(f"{mod}.generate_completion", return_value="Generated answer"),
        patch(f"{mod}.CacheConfig") as mock_cache_config,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        mock_cache_config.return_value = cache_cfg

        completion = await retriever.get_completion("test query", context="Provided context")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session(mock_vector_engine):
    """With caching on and a session user present, conversation history is saved."""
    hit = MagicMock()
    hit.payload = {"text": "Test triplet"}
    mock_vector_engine.has_collection.return_value = True
    mock_vector_engine.search.return_value = [hit]

    retriever = TripletRetriever()

    session_user_stub = MagicMock()
    session_user_stub.id = "test-user-id"

    mod = "cognee.modules.retrieval.triplet_retriever"
    with (
        patch(f"{mod}.get_vector_engine", return_value=mock_vector_engine),
        patch(f"{mod}.get_conversation_history", return_value="Previous conversation"),
        patch(f"{mod}.summarize_text", return_value="Context summary"),
        patch(f"{mod}.generate_completion", return_value="Generated answer"),
        patch(f"{mod}.save_conversation_history") as mock_save,
        patch(f"{mod}.CacheConfig") as mock_cache_config,
        patch(f"{mod}.session_user") as mock_session_user,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = True  # enable the session-caching code path
        mock_cache_config.return_value = cache_cfg
        mock_session_user.get.return_value = session_user_stub

        completion = await retriever.get_completion("test query", session_id="test_session")

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert completion[0] == "Generated answer"
    mock_save.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_session_no_user_id(mock_vector_engine):
    """Caching enabled but no user in session: completion still succeeds."""
    hit = MagicMock()
    hit.payload = {"text": "Test triplet"}
    mock_vector_engine.has_collection.return_value = True
    mock_vector_engine.search.return_value = [hit]

    retriever = TripletRetriever()

    mod = "cognee.modules.retrieval.triplet_retriever"
    with (
        patch(f"{mod}.get_vector_engine", return_value=mock_vector_engine),
        patch(f"{mod}.generate_completion", return_value="Generated answer"),
        patch(f"{mod}.CacheConfig") as mock_cache_config,
        patch(f"{mod}.session_user") as mock_session_user,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = True
        mock_cache_config.return_value = cache_cfg
        mock_session_user.get.return_value = None  # no user bound to the session

        completion = await retriever.get_completion("test query")

    assert isinstance(completion, list)
    assert len(completion) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_completion_with_response_model(mock_vector_engine):
    """A custom pydantic response_model instance is returned inside the list."""
    from pydantic import BaseModel

    class TestModel(BaseModel):
        answer: str

    hit = MagicMock()
    hit.payload = {"text": "Test triplet"}
    mock_vector_engine.has_collection.return_value = True
    mock_vector_engine.search.return_value = [hit]

    retriever = TripletRetriever()

    mod = "cognee.modules.retrieval.triplet_retriever"
    with (
        patch(f"{mod}.get_vector_engine", return_value=mock_vector_engine),
        patch(f"{mod}.generate_completion", return_value=TestModel(answer="Test answer")),
        patch(f"{mod}.CacheConfig") as mock_cache_config,
    ):
        cache_cfg = MagicMock()
        cache_cfg.caching = False
        mock_cache_config.return_value = cache_cfg

        completion = await retriever.get_completion("test query", response_model=TestModel)

    assert isinstance(completion, list)
    assert len(completion) == 1
    assert isinstance(completion[0], TestModel)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_none_top_k():
    """Passing top_k=None falls back to the default of 5."""
    retriever = TripletRetriever(top_k=None)

    # None must not override the default search limit.
    assert retriever.top_k == 5
|
||||
Loading…
Add table
Reference in a new issue