From f79ba53e1d97dfcb8843d1343f6f1fcad3c7b31f Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 17 Dec 2025 12:30:15 +0100 Subject: [PATCH] COG-3532 chore: retriever test reorganization + adding new tests (unit) (STEP 2) (#1892) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR restructures and adds unit tests for the retrieval module. (STEP 2) - Added missing unit tests for all core retrieval business logic ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **Tests** * Expanded and refactored retrieval module test suites with comprehensive unit test coverage for ChunksRetriever, SummariesRetriever, RagCompletionRetriever, TripletRetriever, GraphCompletionRetriever, TemporalRetriever, and related components. 
* Added new test modules for completion utilities, graph summary retrieval, and user feedback functionality. * Improved test robustness with edge case handling and error scenario coverage. ✏️ Tip: You can customize this high-level summary in your review settings. --- .../retrieval/chunks_retriever_test.py | 183 ++++ .../retrieval/conversation_history_test.py | 492 +++++++++++ ...letion_retriever_context_extension_test.py | 469 ++++++++++ .../graph_completion_retriever_cot_test.py | 688 +++++++++++++++ .../graph_completion_retriever_test.py | 648 ++++++++++++++ .../rag_completion_retriever_test.py | 321 +++++++ .../retrieval/summaries_retriever_test.py | 193 +++++ .../retrieval/temporal_retriever_test.py | 705 +++++++++++++++ .../test_brute_force_triplet_search.py | 817 ++++++++++++++++++ .../unit/modules/retrieval/test_completion.py | 343 ++++++++ ...test_graph_summary_completion_retriever.py | 157 ++++ .../retrieval/test_user_qa_feedback.py | 312 +++++++ .../retrieval/triplet_retriever_test.py | 329 +++++++ 13 files changed, 5657 insertions(+) create mode 100644 cognee/tests/unit/modules/retrieval/chunks_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/conversation_history_test.py create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/summaries_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/temporal_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py create mode 100644 cognee/tests/unit/modules/retrieval/test_completion.py create mode 100644 
cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py create mode 100644 cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py create mode 100644 cognee/tests/unit/modules/retrieval/triplet_retriever_test.py diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py new file mode 100644 index 000000000..98bfd48fe --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -0,0 +1,183 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of chunk context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1} + + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + retriever = ChunksRetriever(top_k=5) + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 2 + assert context[0]["text"] == "Steve Rodger" + assert context[1]["text"] == "Mike Broski" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5) + + +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + 
mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty list is returned when no chunks are found.""" + mock_vector_engine.search.return_value = [] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(3)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Chunk {i}"} + + mock_vector_engine.search.return_value = mock_results + + retriever = ChunksRetriever(top_k=3) + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 3 + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3) + + +@pytest.mark.asyncio +async def test_get_completion_with_context(mock_vector_engine): + """Test get_completion returns provided context.""" + retriever = ChunksRetriever() + + provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}] + completion = await retriever.get_completion("test query", context=provided_context) + + assert completion == provided_context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not 
provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Steve Rodger"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query") + + assert len(completion) == 1 + assert completion[0]["text"] == "Steve Rodger" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test ChunksRetriever initialization with defaults.""" + retriever = ChunksRetriever() + + assert retriever.top_k == 5 + + +@pytest.mark.asyncio +async def test_init_custom_top_k(): + """Test ChunksRetriever initialization with custom top_k.""" + retriever = ChunksRetriever(top_k=10) + + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test ChunksRetriever initialization with None top_k.""" + retriever = ChunksRetriever(top_k=None) + + assert retriever.top_k is None + + +@pytest.mark.asyncio +async def test_get_context_empty_payload(mock_vector_engine): + """Test get_context handles empty payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 1 + assert context[0] == {} + + +@pytest.mark.asyncio +async def test_get_completion_with_session_id(mock_vector_engine): + """Test get_completion with session_id parameter.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Steve Rodger"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await 
retriever.get_completion("test query", session_id="test_session") + + assert len(completion) == 1 + assert completion[0]["text"] == "Steve Rodger" diff --git a/cognee/tests/unit/modules/retrieval/conversation_history_test.py b/cognee/tests/unit/modules/retrieval/conversation_history_test.py new file mode 100644 index 000000000..f1ce9b370 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/conversation_history_test.py @@ -0,0 +1,492 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from cognee.context_global_variables import session_user +import importlib + + +def create_mock_cache_engine(qa_history=None): + mock_cache = AsyncMock() + if qa_history is None: + qa_history = [] + mock_cache.get_latest_qa = AsyncMock(return_value=qa_history) + mock_cache.add_qa = AsyncMock(return_value=None) + return mock_cache + + +def create_mock_user(): + mock_user = MagicMock() + mock_user.id = "test-user-id-123" + return mock_user + + +class TestConversationHistoryUtils: + @pytest.mark.asyncio + async def test_get_conversation_history_returns_empty_when_no_history(self): + user = create_mock_user() + session_user.set(user) + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + from cognee.modules.retrieval.utils.session_cache import get_conversation_history + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_formats_history_correctly(self): + """Test get_conversation_history formats Q&A history with correct structure.""" + user = create_mock_user() + session_user.set(user) + + mock_history = [ + { + "time": "2024-01-15 10:30:45", + "question": "What is AI?", + "context": "AI is artificial intelligence", + "answer": "AI stands for Artificial Intelligence", + } + ] + 
mock_cache = create_mock_cache_engine(mock_history) + + # Import the real module to patch safely + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert "Previous conversation:" in result + assert "[2024-01-15 10:30:45]" in result + assert "QUESTION: What is AI?" in result + assert "CONTEXT: AI is artificial intelligence" in result + assert "ANSWER: AI stands for Artificial Intelligence" in result + + @pytest.mark.asyncio + async def test_save_to_session_cache_saves_correctly(self): + """Test save_conversation_history calls add_qa with correct parameters.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="What is Python?", + context_summary="Python is a programming language", + answer="Python is a high-level programming language", + session_id="my_session", + ) + + assert result is True + mock_cache.add_qa.assert_called_once() + + call_kwargs = mock_cache.add_qa.call_args.kwargs + 
assert call_kwargs["question"] == "What is Python?" + assert call_kwargs["context"] == "Python is a programming language" + assert call_kwargs["answer"] == "Python is a high-level programming language" + assert call_kwargs["session_id"] == "my_session" + + @pytest.mark.asyncio + async def test_save_to_session_cache_uses_default_session_when_none(self): + """Test save_conversation_history uses 'default_session' when session_id is None.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + session_id=None, + ) + + assert result is True + call_kwargs = mock_cache.add_qa.call_args.kwargs + assert call_kwargs["session_id"] == "default_session" + + @pytest.mark.asyncio + async def test_save_conversation_history_no_user_id(self): + """Test save_conversation_history returns False when user_id is None.""" + session_user.set(None) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def 
test_save_conversation_history_caching_disabled(self): + """Test save_conversation_history returns False when caching is disabled.""" + user = create_mock_user() + session_user.set(user) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = False + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_cache_engine_none(self): + """Test save_conversation_history returns False when cache_engine is None.""" + user = create_mock_user() + session_user.set(user) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=None): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_cache_connection_error(self): + """Test save_conversation_history handles CacheConnectionError gracefully.""" + user = create_mock_user() + session_user.set(user) + + from cognee.infrastructure.databases.exceptions import CacheConnectionError + + mock_cache = create_mock_cache_engine([]) + mock_cache.add_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) + + cache_module = importlib.import_module( + 
"cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_generic_exception(self): + """Test save_conversation_history handles generic exceptions gracefully.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + mock_cache.add_qa = AsyncMock(side_effect=ValueError("Unexpected error")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_get_conversation_history_no_user_id(self): + """Test get_conversation_history returns empty string when user_id is None.""" + session_user.set(None) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from 
cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_caching_disabled(self): + """Test get_conversation_history returns empty string when caching is disabled.""" + user = create_mock_user() + session_user.set(user) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = False + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_default_session(self): + """Test get_conversation_history uses 'default_session' when session_id is None.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + await get_conversation_history(session_id=None) + + mock_cache.get_latest_qa.assert_called_once_with(str(user.id), "default_session") + + @pytest.mark.asyncio + async def test_get_conversation_history_cache_engine_none(self): + """Test get_conversation_history returns empty string when cache_engine is None.""" + user = create_mock_user() + session_user.set(user) + + cache_module = importlib.import_module( + 
"cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=None): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_cache_connection_error(self): + """Test get_conversation_history handles CacheConnectionError gracefully.""" + user = create_mock_user() + session_user.set(user) + + from cognee.infrastructure.databases.exceptions import CacheConnectionError + + mock_cache = create_mock_cache_engine([]) + mock_cache.get_latest_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_generic_exception(self): + """Test get_conversation_history handles generic exceptions gracefully.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + mock_cache.get_latest_qa = AsyncMock(side_effect=ValueError("Unexpected error")) + + cache_module = importlib.import_module( + 
"cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_missing_keys(self): + """Test get_conversation_history handles missing keys in history entries.""" + user = create_mock_user() + session_user.set(user) + + mock_history = [ + { + "time": "2024-01-15 10:30:45", + "question": "What is AI?", + }, + { + "context": "AI is artificial intelligence", + "answer": "AI stands for Artificial Intelligence", + }, + {}, + ] + mock_cache = create_mock_cache_engine(mock_history) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert "Previous conversation:" in result + assert "[2024-01-15 10:30:45]" in result + assert "QUESTION: What is AI?" 
in result + assert "Unknown time" in result + assert "CONTEXT: AI is artificial intelligence" in result + assert "ANSWER: AI stands for Artificial Intelligence" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py new file mode 100644 index 000000000..6a9b07d38 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -0,0 +1,469 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID + +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +@pytest.mark.asyncio +async def test_get_triplets_inherited(mock_edge): + """Test that get_triplets is inherited from parent class.""" + retriever = GraphCompletionContextExtensionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ): + triplets = await retriever.get_triplets("test query") + + assert len(triplets) == 1 + assert triplets[0] == mock_edge + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionContextExtensionRetriever initialization with defaults.""" + retriever = GraphCompletionContextExtensionRetriever() + + assert retriever.top_k == 5 + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionContextExtensionRetriever initialization with custom parameters.""" + retriever = GraphCompletionContextExtensionRetriever( + top_k=10, + 
user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + node_type=str, + node_name=["node1"], + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=5.0, + ) + + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.node_type is str + assert retriever.node_name == ["node1"] + assert retriever.save_interaction is True + assert retriever.wide_search_top_k == 200 + assert retriever.triplet_distance_penalty == 5.0 + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def 
test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_context_extension_rounds(mock_edge): + """Test get_completion with multiple context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[mock_edge], [mock_edge2]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=["Resolved context", "Extended context"], # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Query for extension, then final answer + ), + patch( + 
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_context_extension_stops_early(mock_edge): + """Test get_completion stops early when no new triplets found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # When get_context returns same triplets, the loop should stop early + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=4 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + mock_user = 
MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "test query", session_id="test_session", context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction(mock_edge): + """Test get_completion with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + 
retriever = GraphCompletionContextExtensionRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + TestModel(answer="Test answer"), + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", response_model=TestModel, context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_edge): + """Test get_completion with session config but no user ID.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as 
mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_zero_extension_rounds(mock_edge): + """Test get_completion with zero context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=0) + + assert isinstance(completion, list) + assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py new file mode 100644 index 000000000..9f3147512 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -0,0 +1,688 @@ +import 
pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID + +from cognee.modules.retrieval.graph_completion_cot_retriever import ( + GraphCompletionCotRetriever, + _as_answer_text, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.infrastructure.llm.LLMGateway import LLMGateway + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +@pytest.mark.asyncio +async def test_get_triplets_inherited(mock_edge): + """Test that get_triplets is inherited from parent class.""" + retriever = GraphCompletionCotRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ): + triplets = await retriever.get_triplets("test query") + + assert len(triplets) == 1 + assert triplets[0] == mock_edge + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionCotRetriever initialization with custom parameters.""" + retriever = GraphCompletionCotRetriever( + top_k=10, + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + validation_user_prompt_path="custom_validation_user.txt", + validation_system_prompt_path="custom_validation_system.txt", + followup_system_prompt_path="custom_followup_system.txt", + followup_user_prompt_path="custom_followup_user.txt", + ) + + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.validation_user_prompt_path == "custom_validation_user.txt" + assert retriever.validation_system_prompt_path == "custom_validation_system.txt" + assert retriever.followup_system_prompt_path == "custom_followup_system.txt" + assert retriever.followup_user_prompt_path == "custom_followup_user.txt" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionCotRetriever initialization with 
defaults.""" + retriever = GraphCompletionCotRetriever() + + assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt" + assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt" + assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt" + assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt" + + +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_with_context(mock_edge): + """Test _run_cot_completion round 0 with provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=1, + ) + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_without_context(mock_edge): + """Test _run_cot_completion round 0 without provided context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = 
AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=None, + max_iter=1, + ) + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_multiple_rounds(mock_edge): + """Test _run_cot_completion with multiple rounds.""" + retriever = GraphCompletionCotRetriever() + + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[mock_edge], [mock_edge2]], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=[ + "validation_result", + "followup_question", + "validation_result2", + "followup_question2", + ], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=2, + ) + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_with_conversation_history(mock_edge): + """Test _run_cot_completion with conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="Previous conversation", + max_iter=1, + ) + + assert completion == "Generated answer" + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") == "Previous conversation" + + +@pytest.mark.asyncio +async def test_run_cot_completion_with_response_model(mock_edge): + """Test _run_cot_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + response_model=TestModel, + max_iter=1, + ) + + assert isinstance(completion, TestModel) + 
assert completion.answer == "Test answer" + + +@pytest.mark.asyncio +async def test_run_cot_completion_empty_conversation_history(mock_edge): + """Test _run_cot_completion with empty conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="", + max_iter=1, + ) + + assert completion == "Generated answer" + # Verify conversation_history was passed as None when empty + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") is None + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + mock_user = MagicMock() + mock_user.id 
= "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "test query", session_id="test_session", max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction(mock_edge): + """Test get_completion with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionCotRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + 
mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # Pass context so save_interaction condition is met + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = 
AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", response_model=TestModel, max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_edge): + """Test get_completion with session config but no user ID.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + retriever = GraphCompletionCotRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=None, max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_as_answer_text_with_typeerror(): + """Test 
_as_answer_text handles TypeError when json.dumps fails.""" + non_serializable = {1, 2, 3} + + result = _as_answer_text(non_serializable) + + assert isinstance(result, str) + assert result == str(non_serializable) + + +@pytest.mark.asyncio +async def test_as_answer_text_with_string(): + """Test _as_answer_text with string input.""" + result = _as_answer_text("test string") + assert result == "test string" + + +@pytest.mark.asyncio +async def test_as_answer_text_with_dict(): + """Test _as_answer_text with dictionary input.""" + test_dict = {"key": "value", "number": 42} + result = _as_answer_text(test_dict) + assert isinstance(result, str) + assert "key" in result + assert "value" in result + + +@pytest.mark.asyncio +async def test_as_answer_text_with_basemodel(): + """Test _as_answer_text with Pydantic BaseModel input.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + test_model = TestModel(answer="test answer") + result = _as_answer_text(test_model) + + assert isinstance(result, str) + assert "[Structured Response]" in result + assert "test answer" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py new file mode 100644 index 000000000..c22f30fd0 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -0,0 +1,648 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID + +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +@pytest.mark.asyncio +async def test_get_triplets_success(mock_edge): + """Test successful retrieval of triplets.""" + retriever = GraphCompletionRetriever(top_k=5) + + with patch( + 
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ) as mock_search: + triplets = await retriever.get_triplets("test query") + + assert len(triplets) == 1 + assert triplets[0] == mock_edge + mock_search.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_triplets_empty_results(): + """Test that empty list is returned when no triplets are found.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ): + triplets = await retriever.get_triplets("test query") + + assert triplets == [] + + +@pytest.mark.asyncio +async def test_get_triplets_top_k_parameter(): + """Test that top_k parameter is passed to brute_force_triplet_search.""" + retriever = GraphCompletionRetriever(top_k=10) + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ) as mock_search: + await retriever.get_triplets("test query") + + call_kwargs = mock_search.call_args[1] + assert call_kwargs["top_k"] == 10 + + +@pytest.mark.asyncio +async def test_get_context_success(mock_edge): + """Test successful retrieval of context.""" + retriever = GraphCompletionRetriever() + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + ): + context = await retriever.get_context("test query") + + assert isinstance(context, list) + assert len(context) == 1 + assert context[0] == mock_edge + + +@pytest.mark.asyncio +async def test_get_context_empty_results(): + """Test that empty list is returned when no context is found.""" + retriever = GraphCompletionRetriever() + + 
mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ), + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_empty_graph(): + """Test that empty list is returned when graph is empty.""" + retriever = GraphCompletionRetriever() + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=True) + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_resolve_edges_to_text(mock_edge): + """Test resolve_edges_to_text method.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved text", + ) as mock_resolve: + result = await retriever.resolve_edges_to_text([mock_edge]) + + assert result == "Resolved text" + mock_resolve.assert_awaited_once_with([mock_edge]) + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionRetriever initialization with defaults.""" + retriever = GraphCompletionRetriever() + + assert retriever.top_k == 5 + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.node_type is None + assert retriever.node_name is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionRetriever initialization with custom parameters.""" + retriever = GraphCompletionRetriever( + top_k=10, + 
user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + node_type=str, + node_name=["node1"], + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=5.0, + ) + + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.node_type is str + assert retriever.node_name == ["node1"] + assert retriever.save_interaction is True + assert retriever.wide_search_top_k == 200 + assert retriever.triplet_distance_penalty == 5.0 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test GraphCompletionRetriever initialization with None top_k.""" + retriever = GraphCompletionRetriever(top_k=None) + + assert retriever.top_k == 5 # None defaults to 5 + + +@pytest.mark.asyncio +async def test_convert_retrieved_objects_to_context(mock_edge): + """Test convert_retrieved_objects_to_context method.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved text", + ) as mock_resolve: + result = await retriever.convert_retrieved_objects_to_context([mock_edge]) + + assert result == "Resolved text" + mock_resolve.assert_awaited_once_with([mock_edge]) + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge]) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_empty_context(mock_edge): + """Test get_completion with empty context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_save_qa(mock_edge): + """Test save_qa method.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + 
mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=["uuid1", "uuid2"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], + ) + + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_save_qa_no_triplet_ids(mock_edge): + """Test save_qa when triplets have no extractable IDs.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + return_value=None, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], + ) + + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() + + +@pytest.mark.asyncio +async def test_save_qa_empty_triplets(): + """Test save_qa with empty triplets list.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + 
return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[], + ) + + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_completion(mock_edge): + """Test get_completion with save_interaction but no completion.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=None, # No completion + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] is None + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + with ( + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=None) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge): + """Test get_completion with save_interaction when all conditions are met (line 216).""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + 
UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge]) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_add_data.assert_awaited_once() diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py new file mode 100644 index 000000000..e998d419d --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -0,0 +1,321 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Steve Rodger"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Mike Broski"} + + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + retriever = CompletionRetriever(top_k=2) + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Steve 
Rodger\nMike Broski" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) + + +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty string is returned when no chunks are found.""" + mock_vector_engine.search.return_value = [] + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "" + + +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(2)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Chunk {i}"} + + mock_vector_engine.search.return_value = mock_results + + retriever = CompletionRetriever(top_k=2) + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Chunk 0\nChunk 1" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) + + +@pytest.mark.asyncio +async def test_get_context_single_chunk(mock_vector_engine): + """Test get_context with single chunk result.""" + mock_result = MagicMock() + mock_result.payload = {"text": 
"Single chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Single chunk text" + + +@pytest.mark.asyncio +async def test_get_completion_without_session(mock_vector_engine): + """Test get_completion without session caching.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_vector_engine): + """Test get_completion with provided context.""" + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated 
answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_vector_engine): + """Test get_completion with session caching enabled.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.completion_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_vector_engine): + """Test get_completion with session config but no user ID.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + 
return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test CompletionRetriever initialization with defaults.""" + retriever = CompletionRetriever() + + assert retriever.user_prompt_path == "context_for_question.txt" + assert retriever.system_prompt_path == 
"answer_simple_question.txt" + assert retriever.top_k == 1 + assert retriever.system_prompt is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test CompletionRetriever initialization with custom parameters.""" + retriever = CompletionRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + top_k=10, + ) + + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_context_missing_text_key(mock_vector_engine): + """Test get_context handles missing text key in payload.""" + mock_result = MagicMock() + mock_result.payload = {"other_key": "value"} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(KeyError): + await retriever.get_context("test query") diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py new file mode 100644 index 000000000..e552ac74a --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -0,0 +1,193 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval 
of summary context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "S.R.", "made_from": "chunk1"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "M.B.", "made_from": "chunk2"} + + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + retriever = SummariesRetriever(top_k=5) + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 2 + assert context[0]["text"] == "S.R." + assert context[1]["text"] == "M.B." + mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5) + + +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty list is returned when no summaries are found.""" + mock_vector_engine.search.return_value = [] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(3)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Summary {i}"} + + 
mock_vector_engine.search.return_value = mock_results + + retriever = SummariesRetriever(top_k=3) + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 3 + mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3) + + +@pytest.mark.asyncio +async def test_get_completion_with_context(mock_vector_engine): + """Test get_completion returns provided context.""" + retriever = SummariesRetriever() + + provided_context = [{"text": "S.R."}, {"text": "M.B."}] + completion = await retriever.get_completion("test query", context=provided_context) + + assert completion == provided_context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query") + + assert len(completion) == 1 + assert completion[0]["text"] == "S.R." 
@pytest.mark.asyncio
async def test_init_defaults():
    """Check that a retriever built without arguments reports top_k == 5."""
    default_retriever = SummariesRetriever()

    assert default_retriever.top_k == 5


@pytest.mark.asyncio
async def test_init_custom_top_k():
    """Check that a top_k given to the constructor is exposed unchanged."""
    custom_retriever = SummariesRetriever(top_k=10)

    assert custom_retriever.top_k == 10


@pytest.mark.asyncio
async def test_get_context_empty_payload(mock_vector_engine):
    """Check that a hit with an empty payload yields one empty-dict entry."""
    # A single search hit whose payload carries no keys at all.
    empty_hit = MagicMock(payload={})
    mock_vector_engine.search.return_value = [empty_hit]

    summaries = SummariesRetriever()

    engine_patch = patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch:
        found = await summaries.get_context("test query")

    assert len(found) == 1
    assert found[0] == {}


@pytest.mark.asyncio
async def test_get_completion_with_session_id(mock_vector_engine):
    """Check that get_completion accepts session_id and returns the hit payload."""
    summary_hit = MagicMock(payload={"text": "S.R."})
    mock_vector_engine.search.return_value = [summary_hit]

    summaries = SummariesRetriever()

    engine_patch = patch(
        "cognee.modules.retrieval.summaries_retriever.get_vector_engine",
        return_value=mock_vector_engine,
    )
    with engine_patch:
        answer = await summaries.get_completion("test query", session_id="test_session")

    assert len(answer) == 1
    assert answer[0]["text"] == "S.R."
+ + +@pytest.mark.asyncio +async def test_get_completion_with_kwargs(mock_vector_engine): + """Test get_completion accepts additional kwargs.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", extra_param="value") + + assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py new file mode 100644 index 000000000..1d2f4c84d --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -0,0 +1,705 @@ +from types import SimpleNamespace +import pytest +import os +from unittest.mock import AsyncMock, patch, MagicMock +from datetime import datetime + +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp +from cognee.infrastructure.llm import LLMGateway + + +# Test TemporalRetriever initialization defaults and overrides +def test_init_defaults_and_overrides(): + tr = TemporalRetriever() + assert tr.top_k == 5 + assert tr.user_prompt_path == "graph_context_for_question.txt" + assert tr.system_prompt_path == "answer_simple_question.txt" + assert tr.time_extraction_prompt_path == "extract_query_time.txt" + + tr2 = TemporalRetriever( + top_k=3, + user_prompt_path="u.txt", + system_prompt_path="s.txt", + time_extraction_prompt_path="t.txt", + ) + assert tr2.top_k == 3 + assert tr2.user_prompt_path == "u.txt" + assert tr2.system_prompt_path == "s.txt" + assert tr2.time_extraction_prompt_path == "t.txt" + + +# Test descriptions_to_string with basic and empty results +def test_descriptions_to_string_basic_and_empty(): + tr = TemporalRetriever() + + results = [ + 
{"description": " First "}, + {"nope": "no description"}, + {"description": "Second"}, + {"description": ""}, + {"description": " Third line "}, + ] + + s = tr.descriptions_to_string(results) + assert s == "First\n#####################\nSecond\n#####################\nThird line" + + assert tr.descriptions_to_string([]) == "" + + +# Test filter_top_k_events sorts and limits correctly +@pytest.mark.asyncio +async def test_filter_top_k_events_sorts_and_limits(): + tr = TemporalRetriever(top_k=2) + + relevant_events = [ + { + "events": [ + {"id": "e1", "description": "E1"}, + {"id": "e2", "description": "E2"}, + {"id": "e3", "description": "E3 - not in vector results"}, + ] + } + ] + + scored_results = [ + SimpleNamespace(payload={"id": "e2"}, score=0.10), + SimpleNamespace(payload={"id": "e1"}, score=0.20), + ] + + top = await tr.filter_top_k_events(relevant_events, scored_results) + + assert [e["id"] for e in top] == ["e2", "e1"] + assert all("score" in e for e in top) + assert top[0]["score"] == 0.10 + assert top[1]["score"] == 0.20 + + +# Test filter_top_k_events handles unknown ids as infinite scores +@pytest.mark.asyncio +async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k(): + tr = TemporalRetriever(top_k=2) + + relevant_events = [ + { + "events": [ + {"id": "known1", "description": "Known 1"}, + {"id": "unknown", "description": "Unknown"}, + {"id": "known2", "description": "Known 2"}, + ] + } + ] + + scored_results = [ + SimpleNamespace(payload={"id": "known2"}, score=0.05), + SimpleNamespace(payload={"id": "known1"}, score=0.50), + ] + + top = await tr.filter_top_k_events(relevant_events, scored_results) + assert [e["id"] for e in top] == ["known2", "known1"] + assert all(e["score"] != float("inf") for e in top) + + +# Test descriptions_to_string with unicode and newlines +def test_descriptions_to_string_unicode_and_newlines(): + tr = TemporalRetriever() + results = [ + {"description": "Line A\nwith newline"}, + {"description": 
"This is a description"}, + ] + s = tr.descriptions_to_string(results) + assert "Line A\nwith newline" in s + assert "This is a description" in s + assert s.count("#####################") == 1 + + +# Test filter_top_k_events when top_k is larger than available events +@pytest.mark.asyncio +async def test_filter_top_k_events_limits_when_top_k_exceeds_events(): + tr = TemporalRetriever(top_k=10) + relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}] + scored_results = [ + SimpleNamespace(payload={"id": "a"}, score=0.1), + SimpleNamespace(payload={"id": "b"}, score=0.2), + ] + out = await tr.filter_top_k_events(relevant_events, scored_results) + assert [e["id"] for e in out] == ["a", "b"] + + +# Test filter_top_k_events when scored_results is empty +@pytest.mark.asyncio +async def test_filter_top_k_events_handles_empty_scored_results(): + tr = TemporalRetriever(top_k=2) + relevant_events = [{"events": [{"id": "x"}, {"id": "y"}]}] + scored_results = [] + out = await tr.filter_top_k_events(relevant_events, scored_results) + assert [e["id"] for e in out] == ["x", "y"] + assert all(e["score"] == float("inf") for e in out) + + +# Test filter_top_k_events error handling for missing structure +@pytest.mark.asyncio +async def test_filter_top_k_events_error_handling(): + tr = TemporalRetriever(top_k=2) + with pytest.raises((KeyError, TypeError)): + await tr.filter_top_k_events([{}], []) + + +@pytest.fixture +def mock_graph_engine(): + """Create a mock graph engine.""" + engine = AsyncMock() + engine.collect_time_ids = AsyncMock() + engine.collect_events = AsyncMock() + return engine + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.embedding_engine = AsyncMock() + engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine): + """Test 
get_context when time range is extracted from query.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + {"id": "e2", "description": "Event 2"}, + ] + } + ] + + mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05) + mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10) + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened in 2024?") + + assert isinstance(context, str) + assert len(context) > 0 + assert "Event" in context + + +@pytest.mark.asyncio +async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine): + """Test get_context falls back to triplets when no time is extracted.""" + retriever = TemporalRetriever() + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, + patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve, + ): + + async def mock_extract_time(query): + return None, None + + retriever.extract_time_from_query = mock_extract_time + + context = await retriever.get_context("test query") + + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once() + + +@pytest.mark.asyncio +async def 
test_get_context_no_events_found(mock_graph_engine): + """Test get_context falls back to triplets when no events are found.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = [] + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, + patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve, + ): + + async def mock_extract_time(query): + return "2024-01-01", "2024-12-31" + + retriever.extract_time_from_query = mock_extract_time + + context = await retriever.get_context("test query") + + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine): + """Test get_context with only time_from.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened after 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine): + """Test get_context 
with only time_to.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened before 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + 
mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(): + """Test get_completion uses provided context.""" + retriever = TemporalRetriever() + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine): + """Test get_completion with session caching enabled.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + 
"cognee.modules.retrieval.temporal_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "What happened in 2024?", session_id="test_session" + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine): + """Test get_completion with session config but no user ID.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + 
patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_context_retrieved_but_empty(mock_graph_engine): + """Test get_completion when get_context returns empty string.""" + retriever = TemporalRetriever() + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + ) as mock_get_vector, + patch.object(retriever, "filter_top_k_events", return_value=[]), + ): + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + mock_get_vector.return_value = mock_vector_engine + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": ""}, + ] + } + ] + + with pytest.raises((UnboundLocalError, NameError)): + await retriever.get_completion("test query") + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": 
"e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "What happened in 2024?", response_model=TestModel + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_extract_time_from_query_relative_path(): + """Test extract_time_from_query with relative prompt path.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = 
"11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to + + +@pytest.mark.asyncio +async def test_extract_time_from_query_absolute_path(): + """Test extract_time_from_query with absolute prompt path.""" + retriever = TemporalRetriever( + time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt" + ) + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True), + patch( + "cognee.modules.retrieval.temporal_retriever.os.path.dirname", + return_value="/absolute/path/to", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.os.path.basename", + return_value="extract_query_time.txt", + ), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to + + +@pytest.mark.asyncio +async def test_extract_time_from_query_with_none_values(): + """Test extract_time_from_query when interval has None values.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_interval = QueryInterval(starts_at=None, ends_at=None) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + 
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened?") + + assert time_from is None + assert time_to is None diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py new file mode 100644 index 000000000..b7cbe08d7 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -0,0 +1,817 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.utils.brute_force_triplet_search import ( + brute_force_triplet_search, + get_memory_fragment, + format_triplets, +) +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError +from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_empty_query(): + """Test that empty query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query="") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_none_query(): + """Test that None query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await 
brute_force_triplet_search(query=None) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_negative_top_k(): + """Test that negative top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=-1) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_zero_top_k(): + """Test that zero top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=0) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_global_search(): + """Test that wide_search_limit is applied for global search (node_name=None).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=None, # Global search + wide_search_top_k=75, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 75 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): + """Test that wide_search_limit is None for filtered search (node_name provided).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=["Node1"], + 
wide_search_top_k=50, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_default(): + """Test that wide_search_top_k defaults to 100.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", node_name=None) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 100 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_default_collections(): + """Test that default collections are used when none provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test") + + expected_collections = [ + "Entity_name", + "TextSummary_text", + "EntityType_name", + "DocumentChunk_text", + "EdgeType_relationship_name", + ] + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == expected_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_custom_collections(): + """Test that custom collections are used when provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = 
AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + custom_collections = ["CustomCol1", "CustomCol2"] + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", collections=custom_collections) + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_always_includes_edge_collection(): + """Test that EdgeType_relationship_name is always searched even when not in collections.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + collections_without_edge = ["Entity_name", "TextSummary_text"] + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", collections=collections_without_edge) + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert "EdgeType_relationship_name" in call_collections + assert set(call_collections) == set(collections_without_edge) | { + "EdgeType_relationship_name" + } + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_all_collections_empty(): + """Test that empty list is returned when all collections return no results.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( 
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + results = await brute_force_triplet_search(query="test") + assert results == [] + + +# Tests for query embedding + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_embeds_query(): + """Test that query is embedded before searching.""" + query_text = "test query" + expected_vector = [0.1, 0.2, 0.3] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query=query_text) + + mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text]) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["query_vector"] == expected_vector + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_extracts_node_ids_global_search(): + """Test that node IDs are extracted from search results for global search.""" + scored_results = [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + MockScoredResult("node3", 0.92), + ] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=scored_results) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + 
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_reuses_provided_fragment(): + """Test that provided memory fragment is reused instead of creating new one.""" + provided_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment" + ) as mock_get_fragment, + ): + await brute_force_triplet_search( + query="test", + memory_fragment=provided_fragment, + node_name=["node"], + ) + + mock_get_fragment.assert_not_called() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): + """Test that memory fragment is created when not provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + 
calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test", node_name=["node"]) + + mock_get_fragment.assert_called_once() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation(): + """Test that custom top_k is passed to importance calculation.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + custom_top_k = 15 + await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) + + mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): + """Test that get_memory_fragment returns empty graph when entity not found (line 85).""" + mock_graph_engine = AsyncMock() + + # Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called + mock_fragment = MagicMock(spec=CogneeGraph) + 
mock_fragment.project_graph_from_db = AsyncMock( + side_effect=EntityNotFoundError("Entity not found") + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph", + return_value=mock_fragment, + ), + ): + result = await get_memory_fragment() + + # Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85) + assert result == mock_fragment + mock_fragment.project_graph_from_db.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_error(): + """Test that get_memory_fragment returns empty graph on generic error.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error")) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ): + fragment = await get_memory_fragment() + + assert isinstance(fragment, CogneeGraph) + assert len(fragment.nodes) == 0 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_deduplicates_node_ids(): + """Test that duplicate node IDs across collections are deduplicated.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ] + elif collection_name == "TextSummary_text": + return [ + MockScoredResult("node1", 0.90), + MockScoredResult("node3", 0.92), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + 
map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + assert len(call_kwargs["relevant_ids_to_filter"]) == 3 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_excludes_edge_collection(): + """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "EdgeType_relationship_name": + return [MockScoredResult("edge1", 0.88)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search( + 
query="test", + node_name=None, + collections=["Entity_name", "EdgeType_relationship_name"], + ) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] == ["node1"] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_skips_nodes_without_ids(): + """Test that nodes without ID attribute are skipped.""" + + class ScoredResultNoId: + """Mock result without id attribute.""" + + def __init__(self, score): + self.score = score + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + ScoredResultNoId(0.90), + MockScoredResult("node2", 0.87), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_handles_tuple_results(): + """Test that both list and tuple results are handled correctly.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return 
( + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ) + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_mixed_empty_collections(): + """Test ID extraction with mixed empty and non-empty collections.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "TextSummary_text": + return [] + elif collection_name == "EntityType_name": + return [MockScoredResult("node2", 0.92)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), 
+ ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +def test_format_triplets(): + """Test format_triplets function.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"} + mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "connects" in result + + +def test_format_triplets_with_none_values(): + """Test format_triplets filters out None values.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None} + mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "None" not in result or result.count("None") == 0 + + +def test_format_triplets_with_nested_dict(): + """Test format_triplets handles nested dict attributes (lines 23-35).""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = 
MagicMock() + + mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}} + mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}} + mock_edge.attributes = {"relationship_name": "relates_to"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_vector_engine_init_error(): + """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147).""" + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine" + ) as mock_get_vector_engine, + ): + mock_get_vector_engine.side_effect = Exception("Initialization error") + + with pytest.raises(RuntimeError, match="Initialization error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_error(): + """Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock( + side_effect=[ + CollectionNotFoundError("Collection not found"), + [], + [], + ] + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=CogneeGraph(), + ), + ): + result = await brute_force_triplet_search( + query="test query", collections=["missing_collection", "existing_collection"] + ) + + assert result == [] + + +@pytest.mark.asyncio 
+async def test_brute_force_triplet_search_generic_exception(): + """Test brute_force_triplet_search handles generic exceptions (lines 209-217).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock(side_effect=Exception("Generic error")) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + with pytest.raises(Exception, match="Generic error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none(): + """Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test query", node_name=["Node1"]) + + assert mock_get_fragment.called + call_kwargs = mock_get_fragment.call_args.kwargs if 
mock_get_fragment.call_args else {} + assert call_kwargs.get("relevant_ids_to_filter") is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_at_top_level(): + """Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock( + side_effect=CollectionNotFoundError("Collection not found") + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + result = await brute_force_triplet_search(query="test query") + + assert result == [] diff --git a/cognee/tests/unit/modules/retrieval/test_completion.py b/cognee/tests/unit/modules/retrieval/test_completion.py new file mode 100644 index 000000000..9a836c2cc --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_completion.py @@ -0,0 +1,343 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from typing import Type + + +class TestGenerateCompletion: + @pytest.mark.asyncio + async def test_generate_completion_with_system_prompt(self): + """Test generate_completion with provided system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + 
return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + system_prompt="Custom system prompt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="Custom system prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_without_system_prompt(self): + """Test generate_completion reads system_prompt from file when not provided.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="System prompt from file", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_conversation_history(self): + """Test generate_completion includes conversation_history in system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + 
patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning", + ) + + assert result == mock_llm_response + expected_system_prompt = ( + "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" + + "\nTASK:" + + "System prompt from file" + ) + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt=expected_system_prompt, + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self): + """Test generate_completion includes conversation_history with custom system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + system_prompt="Custom system prompt", + conversation_history="Previous conversation:\nQ: What is ML?\nA: 
ML is machine learning", + ) + + assert result == mock_llm_response + expected_system_prompt = ( + "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" + + "\nTASK:" + + "Custom system prompt" + ) + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt=expected_system_prompt, + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_response_model(self): + """Test generate_completion with custom response_model.""" + mock_response_model = MagicMock() + mock_llm_response = {"answer": "Generated answer"} + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + response_model=mock_response_model, + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="System prompt from file", + response_model=mock_response_model, + ) + + @pytest.mark.asyncio + async def test_generate_completion_render_prompt_args(self): + """Test generate_completion passes correct args to render_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ) as mock_render, + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + 
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ), + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + ) + + mock_render.assert_called_once_with( + "user_prompt.txt", + {"question": "What is AI?", "context": "AI is artificial intelligence"}, + ) + + +class TestSummarizeText: + @pytest.mark.asyncio + async def test_summarize_text_with_system_prompt(self): + """Test summarize_text with provided system_prompt.""" + mock_llm_response = "Summary text" + + with patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm: + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="summarize_search_results.txt", + system_prompt="Custom summary prompt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Custom summary prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_without_system_prompt(self): + """Test summarize_text reads system_prompt from file when not provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + 
system_prompt_path="summarize_search_results.txt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="System prompt from file", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_default_prompt_path(self): + """Test summarize_text uses default prompt path when not provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="Default system prompt", + ) as mock_read, + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text(text="Long text to summarize") + + assert result == mock_llm_response + mock_read.assert_called_once_with("summarize_search_results.txt") + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Default system prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_custom_prompt_path(self): + """Test summarize_text uses custom prompt path when provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="Custom system prompt", + ) as mock_read, + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="custom_prompt.txt", + ) + + assert result == mock_llm_response + mock_read.assert_called_once_with("custom_prompt.txt") + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + 
system_prompt="Custom system prompt", + response_model=str, + ) diff --git a/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py new file mode 100644 index 000000000..2af10da5e --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py @@ -0,0 +1,157 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.graph_summary_completion_retriever import ( + GraphSummaryCompletionRetriever, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +class TestGraphSummaryCompletionRetriever: + @pytest.mark.asyncio + async def test_init_defaults(self): + """Test GraphSummaryCompletionRetriever initialization with defaults.""" + retriever = GraphSummaryCompletionRetriever() + + assert retriever.summarize_prompt_path == "summarize_search_results.txt" + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 5 + assert retriever.save_interaction is False + + @pytest.mark.asyncio + async def test_init_custom_params(self): + """Test GraphSummaryCompletionRetriever initialization with custom parameters.""" + retriever = GraphSummaryCompletionRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + summarize_prompt_path="custom_summarize.txt", + system_prompt="Custom system prompt", + top_k=10, + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=2.5, + ) + + assert retriever.summarize_prompt_path == "custom_summarize.txt" + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.top_k == 10 + assert 
retriever.save_interaction is True + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge): + """Test resolve_edges_to_text calls super method and then summarizes.""" + retriever = GraphSummaryCompletionRetriever( + summarize_prompt_path="custom_summarize.txt", + system_prompt="Custom system prompt", + ) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Resolved edges text", + ) as mock_super_resolve, + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Summarized text", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([mock_edge]) + + assert result == "Summarized text" + mock_super_resolve.assert_awaited_once_with([mock_edge]) + mock_summarize.assert_awaited_once_with( + "Resolved edges text", + "custom_summarize.txt", + "Custom system prompt", + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge): + """Test resolve_edges_to_text uses None for system_prompt when not provided.""" + retriever = GraphSummaryCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Resolved edges text", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Summarized text", + ) as mock_summarize, + ): + await retriever.resolve_edges_to_text([mock_edge]) + + mock_summarize.assert_awaited_once_with( + "Resolved edges text", + "summarize_search_results.txt", + None, + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_empty_edges(self): + """Test resolve_edges_to_text handles empty edges list.""" + retriever = 
GraphSummaryCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Empty summary", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([]) + + assert result == "Empty summary" + mock_summarize.assert_awaited_once_with( + "", + "summarize_search_results.txt", + None, + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge): + """Test resolve_edges_to_text handles multiple edges.""" + retriever = GraphSummaryCompletionRetriever() + + mock_edge2 = MagicMock(spec=Edge) + mock_edge3 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Multiple edges resolved text", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Multiple edges summarized", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([mock_edge, mock_edge2, mock_edge3]) + + assert result == "Multiple edges summarized" + mock_summarize.assert_awaited_once_with( + "Multiple edges resolved text", + "summarize_search_results.txt", + None, + ) diff --git a/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py new file mode 100644 index 000000000..a1e746bb9 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py @@ -0,0 +1,312 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID, NAMESPACE_OID, uuid5 + +from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback +from 
cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment +from cognee.modules.engine.models import NodeSet + + +@pytest.fixture +def mock_feedback_evaluation(): + """Create a mock feedback evaluation.""" + evaluation = MagicMock(spec=UserFeedbackEvaluation) + evaluation.evaluation = MagicMock() + evaluation.evaluation.value = "positive" + evaluation.score = 4.5 + return evaluation + + +@pytest.fixture +def mock_graph_engine(): + """Create a mock graph engine.""" + engine = AsyncMock() + engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + engine.add_edges = AsyncMock() + engine.apply_feedback_weight = AsyncMock() + return engine + + +class TestUserQAFeedback: + @pytest.mark.asyncio + async def test_init_default(self): + """Test UserQAFeedback initialization with default last_k.""" + retriever = UserQAFeedback() + assert retriever.last_k == 1 + + @pytest.mark.asyncio + async def test_init_custom_last_k(self): + """Test UserQAFeedback initialization with custom last_k.""" + retriever = UserQAFeedback(last_k=5) + assert retriever.last_k == 5 + + @pytest.mark.asyncio + async def test_add_feedback_success_with_relationships( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback successfully creates feedback with relationships.""" + interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock( + return_value=[interaction_id_1, interaction_id_2] + ) + + feedback_text = "This answer was helpful" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + 
new_callable=AsyncMock, + ) as mock_add_data, + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ) as mock_index_edges, + ): + retriever = UserQAFeedback(last_k=2) + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_awaited_once() + mock_index_edges.assert_awaited_once() + mock_graph_engine.apply_feedback_weight.assert_awaited_once() + + # Verify add_edges was called with correct relationships + call_args = mock_graph_engine.add_edges.call_args[0][0] + assert len(call_args) == 2 + assert call_args[0][0] == uuid5(NAMESPACE_OID, name=feedback_text) + assert call_args[0][1] == UUID(interaction_id_1) + assert call_args[0][2] == "gives_feedback_to" + assert call_args[0][3]["relationship_name"] == "gives_feedback_to" + assert call_args[0][3]["ontology_valid"] is False + + # Verify apply_feedback_weight was called with correct node IDs + weight_call_args = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"] + assert len(weight_call_args) == 2 + assert interaction_id_1 in weight_call_args + assert interaction_id_2 in weight_call_args + + @pytest.mark.asyncio + async def test_add_feedback_success_no_relationships( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback successfully creates feedback without relationships.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "This answer was helpful" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + patch( + 
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ) as mock_index_edges, + ): + retriever = UserQAFeedback(last_k=1) + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + mock_add_data.assert_awaited_once() + # Should not call add_edges or index_graph_edges when no relationships + mock_graph_engine.add_edges.assert_not_awaited() + mock_index_edges.assert_not_awaited() + mock_graph_engine.apply_feedback_weight.assert_not_awaited() + + @pytest.mark.asyncio + async def test_add_feedback_creates_correct_feedback_node( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback creates CogneeUserFeedback with correct attributes.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "This was a negative experience" + mock_feedback_evaluation.evaluation.value = "negative" + mock_feedback_evaluation.score = -3.0 + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + # Verify add_data_points was called with correct CogneeUserFeedback + call_args = mock_add_data.call_args[1]["data_points"] + assert len(call_args) == 1 + feedback_node = call_args[0] + assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text) + assert feedback_node.feedback == feedback_text + assert feedback_node.sentiment == "negative" + assert feedback_node.score == -3.0 + assert isinstance(feedback_node.belongs_to_set, NodeSet) + assert feedback_node.belongs_to_set.name == "UserQAFeedbacks" + + @pytest.mark.asyncio + 
async def test_add_feedback_calls_llm_with_correct_prompt( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback calls LLM with correct sentiment analysis prompt.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "Great answer!" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ) as mock_llm, + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + mock_llm.assert_awaited_once() + call_kwargs = mock_llm.call_args[1] + assert call_kwargs["text_input"] == feedback_text + assert "sentiment analysis assistant" in call_kwargs["system_prompt"] + assert call_kwargs["response_model"] == UserFeedbackEvaluation + + @pytest.mark.asyncio + async def test_add_feedback_uses_last_k_parameter( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback uses last_k parameter when getting interaction IDs.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "Test feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback(last_k=5) + await retriever.add_feedback(feedback_text) + + mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5) + + @pytest.mark.asyncio + async def 
test_add_feedback_with_single_interaction( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback with single interaction ID.""" + interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) + + feedback_text = "Test feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + # Should create relationship for the interaction + call_args = mock_graph_engine.add_edges.call_args[0][0] + assert len(call_args) == 1 + assert call_args[0][1] == UUID(interaction_id) + + @pytest.mark.asyncio + async def test_add_feedback_applies_weight_correctly( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback applies feedback weight with correct score.""" + interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) + mock_feedback_evaluation.score = 4.5 + + feedback_text = "Positive feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + 
), + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + mock_graph_engine.apply_feedback_weight.assert_awaited_once_with( + node_ids=[interaction_id], weight=4.5 + ) diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py new file mode 100644 index 000000000..83612c7aa --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py @@ -0,0 +1,329 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.triplet_retriever import TripletRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.has_collection = AsyncMock(return_value=True) + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of triplet context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Alice knows Bob"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Bob works at Tech Corp"} + + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + retriever = TripletRetriever(top_k=5) + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Alice knows Bob\nBob works at Tech Corp" + mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5) + + +@pytest.mark.asyncio +async def test_get_context_no_collection(mock_vector_engine): + """Test that NoDataError is raised when 
Triplet_text collection doesn't exist.""" + mock_vector_engine.has_collection.return_value = False + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="create_triplet_embeddings"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty string is returned when no triplets are found.""" + mock_vector_engine.search.return_value = [] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "" + + +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found") + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_payload_text(mock_vector_engine): + """Test get_context handles missing text in payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(KeyError): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_single_triplet(mock_vector_engine): + """Test get_context with single triplet result.""" + mock_result = 
MagicMock() + mock_result.payload = {"text": "Single triplet"} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Single triplet" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test TripletRetriever initialization with defaults.""" + retriever = TripletRetriever() + + assert retriever.user_prompt_path == "context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 5 # Default is 5 + assert retriever.system_prompt is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test TripletRetriever initialization with custom parameters.""" + retriever = TripletRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + top_k=10, + ) + + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + 
mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_vector_engine): + """Test get_completion uses provided context.""" + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_vector_engine): + """Test get_completion with session caching enabled.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.save_conversation_history", + ) as mock_save, + 
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_vector_engine): + """Test get_completion with session config but no user ID.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + 
mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test TripletRetriever initialization with None top_k.""" + retriever = TripletRetriever(top_k=None) + + assert retriever.top_k == 5