diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py deleted file mode 100644 index 98bfd48fe..000000000 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ /dev/null @@ -1,183 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock - -from cognee.modules.retrieval.chunks_retriever import ChunksRetriever -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError - - -@pytest.fixture -def mock_vector_engine(): - """Create a mock vector engine.""" - engine = AsyncMock() - engine.search = AsyncMock() - return engine - - -@pytest.mark.asyncio -async def test_get_context_success(mock_vector_engine): - """Test successful retrieval of chunk context.""" - mock_result1 = MagicMock() - mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0} - mock_result2 = MagicMock() - mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1} - - mock_vector_engine.search.return_value = [mock_result1, mock_result2] - - retriever = ChunksRetriever(top_k=5) - - with patch( - "cognee.modules.retrieval.chunks_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert len(context) == 2 - assert context[0]["text"] == "Steve Rodger" - assert context[1]["text"] == "Mike Broski" - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5) - - -@pytest.mark.asyncio -async def test_get_context_collection_not_found_error(mock_vector_engine): - """Test that CollectionNotFoundError is converted to NoDataError.""" - mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") - - retriever = ChunksRetriever() - - with patch( - "cognee.modules.retrieval.chunks_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="No data found"): - await retriever.get_context("test query") - - -@pytest.mark.asyncio -async def test_get_context_empty_results(mock_vector_engine): - """Test that empty list is returned when no chunks are found.""" - mock_vector_engine.search.return_value = [] - - retriever = ChunksRetriever() - - with patch( - "cognee.modules.retrieval.chunks_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == [] - - -@pytest.mark.asyncio -async def test_get_context_top_k_limit(mock_vector_engine): - """Test that top_k parameter limits the number of results.""" - mock_results = [MagicMock() for _ in range(3)] - for i, result in enumerate(mock_results): - result.payload = {"text": f"Chunk {i}"} - - mock_vector_engine.search.return_value = mock_results - - retriever = ChunksRetriever(top_k=3) - - with patch( - "cognee.modules.retrieval.chunks_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert len(context) == 3 - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3) - - -@pytest.mark.asyncio -async def test_get_completion_with_context(mock_vector_engine): - """Test get_completion returns provided context.""" - retriever = ChunksRetriever() - - provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}] - completion = await retriever.get_completion("test query", context=provided_context) - - assert completion == provided_context - - -@pytest.mark.asyncio -async def test_get_completion_without_context(mock_vector_engine): - """Test get_completion retrieves context when not provided.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Steve Rodger"} - mock_vector_engine.search.return_value = [mock_result] - - retriever = ChunksRetriever() - - with patch( - "cognee.modules.retrieval.chunks_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - completion = await retriever.get_completion("test query") - - assert len(completion) == 1 - assert completion[0]["text"] == "Steve Rodger" - - -@pytest.mark.asyncio -async def test_init_defaults(): - """Test ChunksRetriever initialization with defaults.""" - retriever = ChunksRetriever() - - assert retriever.top_k == 5 - - -@pytest.mark.asyncio -async def test_init_custom_top_k(): - """Test ChunksRetriever initialization with custom top_k.""" - retriever = ChunksRetriever(top_k=10) - - assert retriever.top_k == 10 - - -@pytest.mark.asyncio -async def test_init_none_top_k(): - """Test ChunksRetriever initialization with None top_k.""" - retriever = ChunksRetriever(top_k=None) - - assert retriever.top_k is None - - -@pytest.mark.asyncio -async def test_get_context_empty_payload(mock_vector_engine): - """Test get_context handles empty payload.""" - mock_result = MagicMock() - mock_result.payload = {} - - mock_vector_engine.search.return_value = [mock_result] - - retriever = ChunksRetriever() - - with patch( - "cognee.modules.retrieval.chunks_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert len(context) == 1 - assert context[0] == {} - - -@pytest.mark.asyncio -async def test_get_completion_with_session_id(mock_vector_engine): - """Test get_completion with session_id parameter.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Steve Rodger"} - mock_vector_engine.search.return_value = [mock_result] - - retriever = ChunksRetriever() - - with patch( - "cognee.modules.retrieval.chunks_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - completion = await retriever.get_completion("test query", session_id="test_session") - - assert len(completion) == 1 - assert completion[0]["text"] == "Steve Rodger" diff --git a/cognee/tests/unit/modules/retrieval/conversation_history_test.py b/cognee/tests/unit/modules/retrieval/conversation_history_test.py deleted file mode 100644 index f1ce9b370..000000000 --- a/cognee/tests/unit/modules/retrieval/conversation_history_test.py +++ /dev/null @@ -1,492 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from cognee.context_global_variables import session_user -import importlib - - -def create_mock_cache_engine(qa_history=None): - mock_cache = AsyncMock() - if qa_history is None: - qa_history = [] - mock_cache.get_latest_qa = AsyncMock(return_value=qa_history) - mock_cache.add_qa = AsyncMock(return_value=None) - return mock_cache - - -def create_mock_user(): - mock_user = MagicMock() - mock_user.id = "test-user-id-123" - return mock_user - - -class TestConversationHistoryUtils: - @pytest.mark.asyncio - async def test_get_conversation_history_returns_empty_when_no_history(self): - user = create_mock_user() - session_user.set(user) - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - from cognee.modules.retrieval.utils.session_cache import get_conversation_history - - result = await get_conversation_history(session_id="test_session") - - assert result == "" - - @pytest.mark.asyncio - async def test_get_conversation_history_formats_history_correctly(self): - """Test get_conversation_history formats Q&A history with correct structure.""" - user = create_mock_user() - session_user.set(user) - - mock_history = [ - { - "time": "2024-01-15 10:30:45", - "question": "What is AI?", - "context": "AI is artificial intelligence", - "answer": "AI stands for Artificial Intelligence", - } - ] - mock_cache = create_mock_cache_engine(mock_history) - - # Import the real module to patch safely - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert "Previous conversation:" in result - assert "[2024-01-15 10:30:45]" in result - assert "QUESTION: What is AI?" in result - assert "CONTEXT: AI is artificial intelligence" in result - assert "ANSWER: AI stands for Artificial Intelligence" in result - - @pytest.mark.asyncio - async def test_save_to_session_cache_saves_correctly(self): - """Test save_conversation_history calls add_qa with correct parameters.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="What is Python?", - context_summary="Python is a programming language", - answer="Python is a high-level programming language", - session_id="my_session", - ) - - assert result is True - mock_cache.add_qa.assert_called_once() - - call_kwargs = mock_cache.add_qa.call_args.kwargs - assert call_kwargs["question"] == "What is Python?" - assert call_kwargs["context"] == "Python is a programming language" - assert call_kwargs["answer"] == "Python is a high-level programming language" - assert call_kwargs["session_id"] == "my_session" - - @pytest.mark.asyncio - async def test_save_to_session_cache_uses_default_session_when_none(self): - """Test save_conversation_history uses 'default_session' when session_id is None.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="Test question", - context_summary="Test context", - answer="Test answer", - session_id=None, - ) - - assert result is True - call_kwargs = mock_cache.add_qa.call_args.kwargs - assert call_kwargs["session_id"] == "default_session" - - @pytest.mark.asyncio - async def test_save_conversation_history_no_user_id(self): - """Test save_conversation_history returns False when user_id is None.""" - session_user.set(None) - - with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="Test question", - context_summary="Test context", - answer="Test answer", - ) - - assert result is False - - @pytest.mark.asyncio - async def test_save_conversation_history_caching_disabled(self): - """Test save_conversation_history returns False when caching is disabled.""" - user = create_mock_user() - session_user.set(user) - - with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = False - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="Test question", - context_summary="Test context", - answer="Test answer", - ) - - assert result is False - - @pytest.mark.asyncio - async def test_save_conversation_history_cache_engine_none(self): - """Test save_conversation_history returns False when cache_engine is None.""" - user = create_mock_user() - session_user.set(user) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=None): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="Test question", - context_summary="Test context", - answer="Test answer", - ) - - assert result is False - - @pytest.mark.asyncio - async def test_save_conversation_history_cache_connection_error(self): - """Test save_conversation_history handles CacheConnectionError gracefully.""" - user = create_mock_user() - session_user.set(user) - - from cognee.infrastructure.databases.exceptions import CacheConnectionError - - mock_cache = create_mock_cache_engine([]) - mock_cache.add_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="Test question", - context_summary="Test context", - answer="Test answer", - ) - - assert result is False - - @pytest.mark.asyncio - async def test_save_conversation_history_generic_exception(self): - """Test save_conversation_history handles generic exceptions gracefully.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - mock_cache.add_qa = AsyncMock(side_effect=ValueError("Unexpected error")) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="Test question", - context_summary="Test context", - answer="Test answer", - ) - - assert result is False - - @pytest.mark.asyncio - async def test_get_conversation_history_no_user_id(self): - """Test get_conversation_history returns empty string when user_id is None.""" - session_user.set(None) - - with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert result == "" - - @pytest.mark.asyncio - async def test_get_conversation_history_caching_disabled(self): - """Test get_conversation_history returns empty string when caching is disabled.""" - user = create_mock_user() - session_user.set(user) - - with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = False - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert result == "" - - @pytest.mark.asyncio - async def test_get_conversation_history_default_session(self): - """Test get_conversation_history uses 'default_session' when session_id is None.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - await get_conversation_history(session_id=None) - - mock_cache.get_latest_qa.assert_called_once_with(str(user.id), "default_session") - - @pytest.mark.asyncio - async def test_get_conversation_history_cache_engine_none(self): - """Test get_conversation_history returns empty string when cache_engine is None.""" - user = create_mock_user() - session_user.set(user) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=None): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert result == "" - - @pytest.mark.asyncio - async def test_get_conversation_history_cache_connection_error(self): - """Test get_conversation_history handles CacheConnectionError gracefully.""" - user = create_mock_user() - session_user.set(user) - - from cognee.infrastructure.databases.exceptions import CacheConnectionError - - mock_cache = create_mock_cache_engine([]) - mock_cache.get_latest_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert result == "" - - @pytest.mark.asyncio - async def test_get_conversation_history_generic_exception(self): - """Test get_conversation_history handles generic exceptions gracefully.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - mock_cache.get_latest_qa = AsyncMock(side_effect=ValueError("Unexpected error")) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert result == "" - - @pytest.mark.asyncio - async def test_get_conversation_history_missing_keys(self): - """Test get_conversation_history handles missing keys in history entries.""" - user = create_mock_user() - session_user.set(user) - - mock_history = [ - { - "time": "2024-01-15 10:30:45", - "question": "What is AI?", - }, - { - "context": "AI is artificial intelligence", - "answer": "AI stands for Artificial Intelligence", - }, - {}, - ] - mock_cache = create_mock_cache_engine(mock_history) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert "Previous conversation:" in result - assert "[2024-01-15 10:30:45]" in result - assert "QUESTION: What is AI?" in result - assert "Unknown time" in result - assert "CONTEXT: AI is artificial intelligence" in result - assert "ANSWER: AI stands for Artificial Intelligence" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py deleted file mode 100644 index 6a9b07d38..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ /dev/null @@ -1,469 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from uuid import UUID - -from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( - GraphCompletionContextExtensionRetriever, -) -from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge - - -@pytest.fixture -def mock_edge(): - """Create a mock edge.""" - edge = MagicMock(spec=Edge) - return edge - - -@pytest.mark.asyncio -async def test_get_triplets_inherited(mock_edge): - """Test that get_triplets is inherited from parent class.""" - retriever = GraphCompletionContextExtensionRetriever() - - with patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ): - triplets = await retriever.get_triplets("test query") - - assert len(triplets) == 1 - assert triplets[0] == mock_edge - - -@pytest.mark.asyncio -async def test_init_defaults(): - """Test GraphCompletionContextExtensionRetriever initialization with defaults.""" - retriever = GraphCompletionContextExtensionRetriever() - - assert retriever.top_k == 5 - assert retriever.user_prompt_path == "graph_context_for_question.txt" - assert retriever.system_prompt_path == "answer_simple_question.txt" - - -@pytest.mark.asyncio -async def test_init_custom_params(): - """Test GraphCompletionContextExtensionRetriever initialization with custom parameters.""" - retriever = GraphCompletionContextExtensionRetriever( - top_k=10, - user_prompt_path="custom_user.txt", - system_prompt_path="custom_system.txt", - system_prompt="Custom prompt", - node_type=str, - node_name=["node1"], - save_interaction=True, - wide_search_top_k=200, - triplet_distance_penalty=5.0, - ) - - assert retriever.top_k == 10 - assert retriever.user_prompt_path == "custom_user.txt" - assert retriever.system_prompt_path == "custom_system.txt" - assert retriever.system_prompt == "Custom prompt" - assert retriever.node_type is str - assert retriever.node_name == ["node1"] - assert retriever.save_interaction is True - assert retriever.wide_search_top_k == 200 - assert retriever.triplet_distance_penalty == 5.0 - - -@pytest.mark.asyncio -async def test_get_completion_without_context(mock_edge): - """Test get_completion retrieves context when not provided.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionContextExtensionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context_extension_rounds=1) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_provided_context(mock_edge): - """Test get_completion uses provided context.""" - retriever = GraphCompletionContextExtensionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion( - "test query", context=[mock_edge], context_extension_rounds=1 - ) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_context_extension_rounds(mock_edge): - """Test get_completion with multiple context extension rounds.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionContextExtensionRetriever() - - # Create a second edge for extension rounds - mock_edge2 = MagicMock(spec=Edge) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object( - retriever, - "get_context", - new_callable=AsyncMock, - side_effect=[[mock_edge], [mock_edge2]], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - side_effect=["Resolved context", "Extended context"], # Different contexts - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - side_effect=[ - "Extension query", - "Generated answer", - ], # Query for extension, then final answer - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context_extension_rounds=1) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_context_extension_stops_early(mock_edge): - """Test get_completion stops early when no new triplets found.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionContextExtensionRetriever() - - with ( - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - side_effect=[ - "Extension query", - "Generated answer", - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - # When get_context returns same triplets, the loop should stop early - completion = await retriever.get_completion( - "test query", context=[mock_edge], context_extension_rounds=4 - ) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_session(mock_edge): - """Test get_completion with session caching enabled.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionContextExtensionRetriever() - - mock_user = MagicMock() - mock_user.id = "test-user-id" - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history", - return_value="Previous conversation", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text", - return_value="Context summary", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - side_effect=[ - "Extension query", - "Generated answer", - ], # Extension query, then final answer - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history", - ) as mock_save, - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" - ) as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = mock_user - - completion = await retriever.get_completion( - "test query", session_id="test_session", context_extension_rounds=1 - ) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - mock_save.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_completion_with_save_interaction(mock_edge): - """Test get_completion with save_interaction enabled.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - mock_graph_engine.add_edges = AsyncMock() - - retriever = GraphCompletionContextExtensionRetriever(save_interaction=True) - - mock_node1 = MagicMock() - mock_node2 = MagicMock() - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - side_effect=[ - "Extension query", - "Generated answer", - ], # Extension query, then final answer - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", - side_effect=[ - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion( - "test query", context=[mock_edge], context_extension_rounds=1 - ) - - assert isinstance(completion, list) - assert len(completion) == 1 - mock_add_data.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_completion_with_response_model(mock_edge): - """Test get_completion with custom response model.""" - from pydantic import BaseModel - - class TestModel(BaseModel): - answer: str - - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionContextExtensionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - side_effect=[ - "Extension query", - TestModel(answer="Test answer"), - ], # Extension query, then final answer - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion( - "test query", response_model=TestModel, context_extension_rounds=1 - ) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert isinstance(completion[0], TestModel) - - -@pytest.mark.asyncio -async def test_get_completion_with_session_no_user_id(mock_edge): - """Test get_completion with session config but no user ID.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionContextExtensionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - side_effect=[ - "Extension query", - "Generated answer", - ], # Extension query, then final answer - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" - ) as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = None # No user - - completion = await retriever.get_completion("test query", context_extension_rounds=1) - - assert isinstance(completion, list) - assert len(completion) == 1 - - -@pytest.mark.asyncio -async def test_get_completion_zero_extension_rounds(mock_edge): - """Test get_completion with zero context extension rounds.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionContextExtensionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context_extension_rounds=0) - - assert isinstance(completion, list) - assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py deleted file mode 100644 index 9f3147512..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ /dev/null @@ -1,688 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from uuid import UUID - -from cognee.modules.retrieval.graph_completion_cot_retriever import ( - GraphCompletionCotRetriever, - _as_answer_text, -) -from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge -from cognee.infrastructure.llm.LLMGateway import LLMGateway - - -@pytest.fixture -def mock_edge(): - """Create a mock edge.""" - edge = MagicMock(spec=Edge) - return edge - - -@pytest.mark.asyncio -async def test_get_triplets_inherited(mock_edge): - """Test that get_triplets is inherited from parent class.""" - retriever = GraphCompletionCotRetriever() - - with patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ): - triplets = await retriever.get_triplets("test query") - - assert len(triplets) == 1 - assert triplets[0] == mock_edge - - -@pytest.mark.asyncio -async def test_init_custom_params(): - """Test GraphCompletionCotRetriever initialization with custom parameters.""" - retriever = GraphCompletionCotRetriever( - top_k=10, - user_prompt_path="custom_user.txt", - system_prompt_path="custom_system.txt", - validation_user_prompt_path="custom_validation_user.txt", - validation_system_prompt_path="custom_validation_system.txt", - followup_system_prompt_path="custom_followup_system.txt", - followup_user_prompt_path="custom_followup_user.txt", - ) - - assert retriever.top_k == 10 - assert retriever.user_prompt_path == "custom_user.txt" - assert retriever.system_prompt_path == "custom_system.txt" - assert retriever.validation_user_prompt_path == "custom_validation_user.txt" - assert retriever.validation_system_prompt_path == "custom_validation_system.txt" - assert retriever.followup_system_prompt_path == "custom_followup_system.txt" - assert retriever.followup_user_prompt_path == "custom_followup_user.txt" - - -@pytest.mark.asyncio -async def test_init_defaults(): - """Test GraphCompletionCotRetriever initialization with defaults.""" - retriever = GraphCompletionCotRetriever() - - assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt" - assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt" - assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt" - assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt" - - -@pytest.mark.asyncio -async def test_run_cot_completion_round_zero_with_context(mock_edge): - """Test _run_cot_completion round 0 with provided context.""" - retriever = GraphCompletionCotRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", - return_value="Rendered prompt", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - side_effect=["validation_result", "followup_question"], - ), - ): - completion, context_text, triplets = await retriever._run_cot_completion( - query="test query", - context=[mock_edge], - max_iter=1, - ) - - assert completion == "Generated answer" - assert context_text == "Resolved context" - assert len(triplets) >= 1 - - -@pytest.mark.asyncio -async def test_run_cot_completion_round_zero_without_context(mock_edge): - """Test _run_cot_completion round 0 without provided context.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionCotRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - ): - completion, context_text, triplets = await retriever._run_cot_completion( - query="test query", - context=None, - max_iter=1, - ) - - assert completion == "Generated answer" - assert context_text == "Resolved context" - assert len(triplets) >= 1 - - -@pytest.mark.asyncio -async def test_run_cot_completion_multiple_rounds(mock_edge): - """Test _run_cot_completion with multiple rounds.""" - retriever = GraphCompletionCotRetriever() - - mock_edge2 = MagicMock(spec=Edge) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch.object( - retriever, - "get_context", - new_callable=AsyncMock, - side_effect=[[mock_edge], [mock_edge2]], - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", - return_value="Rendered prompt", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - side_effect=[ - "validation_result", - "followup_question", - "validation_result2", - "followup_question2", - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", - return_value="Generated answer", - ), - ): - completion, context_text, triplets = await retriever._run_cot_completion( - query="test query", - context=[mock_edge], - max_iter=2, - ) - - assert completion == "Generated answer" - assert context_text == "Resolved context" - assert len(triplets) >= 1 - - -@pytest.mark.asyncio -async def test_run_cot_completion_with_conversation_history(mock_edge): - """Test _run_cot_completion with conversation history.""" - retriever = GraphCompletionCotRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ) as mock_generate, - ): - completion, context_text, triplets = await retriever._run_cot_completion( - query="test query", - context=[mock_edge], - conversation_history="Previous conversation", - max_iter=1, - ) - - assert completion == "Generated answer" - call_kwargs = mock_generate.call_args[1] - assert call_kwargs.get("conversation_history") == "Previous conversation" - - -@pytest.mark.asyncio -async def test_run_cot_completion_with_response_model(mock_edge): - """Test _run_cot_completion with custom response model.""" - from pydantic import BaseModel - - class TestModel(BaseModel): - answer: str - - retriever = GraphCompletionCotRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value=TestModel(answer="Test answer"), - ), - ): - completion, context_text, triplets = await retriever._run_cot_completion( - query="test query", - context=[mock_edge], - response_model=TestModel, - max_iter=1, - ) - - assert isinstance(completion, TestModel) - assert completion.answer == "Test answer" - - -@pytest.mark.asyncio -async def test_run_cot_completion_empty_conversation_history(mock_edge): - """Test _run_cot_completion with empty conversation history.""" - retriever = GraphCompletionCotRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ) as mock_generate, - ): - completion, context_text, triplets = await retriever._run_cot_completion( - query="test query", - context=[mock_edge], - conversation_history="", - max_iter=1, - ) - - assert completion == "Generated answer" - # Verify conversation_history was passed as None when empty - call_kwargs = mock_generate.call_args[1] - assert call_kwargs.get("conversation_history") is None - - -@pytest.mark.asyncio -async def test_get_completion_without_context(mock_edge): - """Test get_completion retrieves context when not provided.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionCotRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", - return_value="Rendered prompt", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - side_effect=["validation_result", "followup_question"], - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", max_iter=1) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_provided_context(mock_edge): - """Test get_completion uses provided context.""" - retriever = GraphCompletionCotRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_session(mock_edge): - """Test get_completion with session caching enabled.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionCotRetriever() - - mock_user = MagicMock() - mock_user.id = "test-user-id" - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history", - return_value="Previous conversation", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text", - return_value="Context summary", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history", - ) as mock_save, - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" - ) as mock_cache_config, - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" - ) as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = mock_user - - completion = await retriever.get_completion( - "test query", session_id="test_session", max_iter=1 - ) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - mock_save.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_completion_with_save_interaction(mock_edge): - """Test get_completion with save_interaction enabled.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - mock_graph_engine.add_edges = AsyncMock() - - retriever = GraphCompletionCotRetriever(save_interaction=True) - - mock_node1 = MagicMock() - mock_node2 = MagicMock() - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", - return_value="Rendered prompt", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - side_effect=["validation_result", "followup_question"], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", - side_effect=[ - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - # Pass context so save_interaction condition is met - completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) - - assert isinstance(completion, list) - assert len(completion) == 1 - mock_add_data.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_completion_with_response_model(mock_edge): - """Test get_completion with custom response model.""" - from pydantic import BaseModel - - class TestModel(BaseModel): - answer: str - - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionCotRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value=TestModel(answer="Test answer"), - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion( - "test query", response_model=TestModel, max_iter=1 - ) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert isinstance(completion[0], TestModel) - - -@pytest.mark.asyncio -async def test_get_completion_with_session_no_user_id(mock_edge): - """Test get_completion with session config but no user ID.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionCotRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" - ) as mock_cache_config, - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" - ) as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = None # No user - - completion = await retriever.get_completion("test query", max_iter=1) - - assert isinstance(completion, list) - assert len(completion) == 1 - - -@pytest.mark.asyncio -async def test_get_completion_with_save_interaction_no_context(mock_edge): - """Test get_completion with save_interaction but no context provided.""" - retriever = GraphCompletionCotRetriever(save_interaction=True) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", - return_value="Generated answer", - ), - patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", - return_value="Rendered prompt", - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - side_effect=["validation_result", "followup_question"], - ), - patch( - "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context=None, max_iter=1) - - assert isinstance(completion, list) - assert len(completion) == 1 - - -@pytest.mark.asyncio -async def test_as_answer_text_with_typeerror(): - """Test _as_answer_text handles TypeError when json.dumps fails.""" - non_serializable = {1, 2, 3} - - result = _as_answer_text(non_serializable) - - assert isinstance(result, str) - assert result == str(non_serializable) - - -@pytest.mark.asyncio -async def test_as_answer_text_with_string(): - """Test _as_answer_text with string input.""" - result = _as_answer_text("test string") - assert result == "test string" - - -@pytest.mark.asyncio -async def test_as_answer_text_with_dict(): - """Test _as_answer_text with dictionary input.""" - test_dict = {"key": "value", "number": 42} - result = _as_answer_text(test_dict) - assert isinstance(result, str) - assert "key" in result - assert "value" in result - - -@pytest.mark.asyncio -async def test_as_answer_text_with_basemodel(): - """Test _as_answer_text with Pydantic BaseModel input.""" - from pydantic import BaseModel - - class TestModel(BaseModel): - answer: str - - test_model = TestModel(answer="test answer") - result = _as_answer_text(test_model) - - assert isinstance(result, str) - assert "[Structured Response]" in result - assert "test answer" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py deleted file mode 100644 index c22f30fd0..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ /dev/null @@ -1,648 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from uuid import UUID - -from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever -from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge - - -@pytest.fixture -def mock_edge(): - """Create a mock edge.""" - edge = MagicMock(spec=Edge) - return edge - - -@pytest.mark.asyncio -async def test_get_triplets_success(mock_edge): - """Test successful retrieval of triplets.""" - retriever = GraphCompletionRetriever(top_k=5) - - with patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ) as mock_search: - triplets = await retriever.get_triplets("test query") - - assert len(triplets) == 1 - assert triplets[0] == mock_edge - mock_search.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_triplets_empty_results(): - """Test that empty list is returned when no triplets are found.""" - retriever = GraphCompletionRetriever() - - with patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[], - ): - triplets = await retriever.get_triplets("test query") - - assert triplets == [] - - -@pytest.mark.asyncio -async def test_get_triplets_top_k_parameter(): - """Test that top_k parameter is passed to brute_force_triplet_search.""" - retriever = GraphCompletionRetriever(top_k=10) - - with patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[], - ) as mock_search: - await retriever.get_triplets("test query") - - call_kwargs = mock_search.call_args[1] - assert call_kwargs["top_k"] == 10 - - -@pytest.mark.asyncio -async def test_get_context_success(mock_edge): - """Test successful retrieval of context.""" - retriever = GraphCompletionRetriever() - - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - ): - context = await retriever.get_context("test query") - - assert isinstance(context, list) - assert len(context) == 1 - assert context[0] == mock_edge - - -@pytest.mark.asyncio -async def test_get_context_empty_results(): - """Test that empty list is returned when no context is found.""" - retriever = GraphCompletionRetriever() - - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[], - ), - ): - context = await retriever.get_context("test query") - - assert context == [] - - -@pytest.mark.asyncio -async def test_get_context_empty_graph(): - """Test that empty list is returned when graph is empty.""" - retriever = GraphCompletionRetriever() - - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=True) - - with patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ): - context = await retriever.get_context("test query") - - assert context == [] - - -@pytest.mark.asyncio -async def test_resolve_edges_to_text(mock_edge): - """Test resolve_edges_to_text method.""" - retriever = GraphCompletionRetriever() - - with patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved text", - ) as mock_resolve: - result = await retriever.resolve_edges_to_text([mock_edge]) - - assert result == "Resolved text" - mock_resolve.assert_awaited_once_with([mock_edge]) - - -@pytest.mark.asyncio -async def test_init_defaults(): - """Test GraphCompletionRetriever initialization with defaults.""" - retriever = GraphCompletionRetriever() - - assert retriever.top_k == 5 - assert retriever.user_prompt_path == "graph_context_for_question.txt" - assert retriever.system_prompt_path == "answer_simple_question.txt" - assert retriever.node_type is None - assert retriever.node_name is None - - -@pytest.mark.asyncio -async def test_init_custom_params(): - """Test GraphCompletionRetriever initialization with custom parameters.""" - retriever = GraphCompletionRetriever( - top_k=10, - user_prompt_path="custom_user.txt", - system_prompt_path="custom_system.txt", - system_prompt="Custom prompt", - node_type=str, - node_name=["node1"], - save_interaction=True, - wide_search_top_k=200, - triplet_distance_penalty=5.0, - ) - - assert retriever.top_k == 10 - assert retriever.user_prompt_path == "custom_user.txt" - assert retriever.system_prompt_path == "custom_system.txt" - assert retriever.system_prompt == "Custom prompt" - assert retriever.node_type is str - assert retriever.node_name == ["node1"] - assert retriever.save_interaction is True - assert retriever.wide_search_top_k == 200 - assert retriever.triplet_distance_penalty == 5.0 - - -@pytest.mark.asyncio -async def test_init_none_top_k(): - """Test GraphCompletionRetriever initialization with None top_k.""" - retriever = GraphCompletionRetriever(top_k=None) - - assert retriever.top_k == 5 # None defaults to 5 - - -@pytest.mark.asyncio -async def test_convert_retrieved_objects_to_context(mock_edge): - """Test convert_retrieved_objects_to_context method.""" - retriever = GraphCompletionRetriever() - - with patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved text", - ) as mock_resolve: - result = await retriever.convert_retrieved_objects_to_context([mock_edge]) - - assert result == "Resolved text" - mock_resolve.assert_awaited_once_with([mock_edge]) - - -@pytest.mark.asyncio -async def test_get_completion_without_context(mock_edge): - """Test get_completion retrieves context when not provided.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_provided_context(mock_edge): - """Test get_completion uses provided context.""" - retriever = GraphCompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context=[mock_edge]) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_session(mock_edge): - """Test get_completion with session caching enabled.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionRetriever() - - mock_user = MagicMock() - mock_user.id = "test-user-id" - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_conversation_history", - return_value="Previous conversation", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.summarize_text", - return_value="Context summary", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.save_conversation_history", - ) as mock_save, - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - patch( - "cognee.modules.retrieval.graph_completion_retriever.session_user" - ) as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = mock_user - - completion = await retriever.get_completion("test query", session_id="test_session") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - mock_save.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_completion_with_response_model(mock_edge): - """Test get_completion with custom response model.""" - from pydantic import BaseModel - - class TestModel(BaseModel): - answer: str - - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value=TestModel(answer="Test answer"), - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", response_model=TestModel) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert isinstance(completion[0], TestModel) - - -@pytest.mark.asyncio -async def test_get_completion_empty_context(mock_edge): - """Test get_completion with empty context.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query") - - assert isinstance(completion, list) - assert len(completion) == 1 - - -@pytest.mark.asyncio -async def test_save_qa(mock_edge): - """Test save_qa method.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.add_edges = AsyncMock() - - retriever = GraphCompletionRetriever() - - mock_node1 = MagicMock() - mock_node2 = MagicMock() - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", - side_effect=["uuid1", "uuid2"], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - ): - await retriever.save_qa( - question="Test question", - answer="Test answer", - context="Test context", - triplets=[mock_edge], - ) - - mock_add_data.assert_awaited_once() - mock_graph_engine.add_edges.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_save_qa_no_triplet_ids(mock_edge): - """Test save_qa when triplets have no extractable IDs.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.add_edges = AsyncMock() - - retriever = GraphCompletionRetriever() - - mock_node1 = MagicMock() - mock_node2 = MagicMock() - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", - return_value=None, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - ): - await retriever.save_qa( - question="Test question", - answer="Test answer", - context="Test context", - triplets=[mock_edge], - ) - - mock_add_data.assert_awaited_once() - mock_graph_engine.add_edges.assert_not_called() - - -@pytest.mark.asyncio -async def test_save_qa_empty_triplets(): - """Test save_qa with empty triplets list.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.add_edges = AsyncMock() - - retriever = GraphCompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - ): - await retriever.save_qa( - question="Test question", - answer="Test answer", - context="Test context", - triplets=[], - ) - - mock_add_data.assert_awaited_once() - mock_graph_engine.add_edges.assert_not_called() - - -@pytest.mark.asyncio -async def test_get_completion_with_save_interaction_no_completion(mock_edge): - """Test get_completion with save_interaction but no completion.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionRetriever(save_interaction=True) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value=None, # No completion - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] is None - - -@pytest.mark.asyncio -async def test_get_completion_with_save_interaction_no_context(mock_edge): - """Test get_completion with save_interaction but no context provided.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionRetriever(save_interaction=True) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context=None) - - assert isinstance(completion, list) - assert len(completion) == 1 - - -@pytest.mark.asyncio -async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge): - """Test get_completion with save_interaction when all conditions are met (line 216).""" - mock_graph_engine = AsyncMock() - mock_graph_engine.is_empty = AsyncMock(return_value=False) - - retriever = GraphCompletionRetriever(save_interaction=True) - - mock_node1 = MagicMock() - mock_node2 = MagicMock() - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", - return_value=[mock_edge], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", - return_value="Resolved context", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", - side_effect=[ - UUID("550e8400-e29b-41d4-a716-446655440000"), - UUID("550e8400-e29b-41d4-a716-446655440001"), - ], - ), - patch( - "cognee.modules.retrieval.graph_completion_retriever.add_data_points", - ) as mock_add_data, - patch( - "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" - ) as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context=[mock_edge]) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - mock_add_data.assert_awaited_once() diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py deleted file mode 100644 index e998d419d..000000000 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ /dev/null @@ -1,321 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock - -from cognee.modules.retrieval.completion_retriever import CompletionRetriever -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError - - -@pytest.fixture -def mock_vector_engine(): - """Create a mock vector engine.""" - engine = AsyncMock() - engine.search = AsyncMock() - return engine - - -@pytest.mark.asyncio -async def test_get_context_success(mock_vector_engine): - """Test successful retrieval of context.""" - mock_result1 = MagicMock() - mock_result1.payload = {"text": "Steve Rodger"} - mock_result2 = MagicMock() - mock_result2.payload = {"text": "Mike Broski"} - - mock_vector_engine.search.return_value = [mock_result1, mock_result2] - - retriever = CompletionRetriever(top_k=2) - - with patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "Steve Rodger\nMike Broski" - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) - - -@pytest.mark.asyncio -async def test_get_context_collection_not_found_error(mock_vector_engine): - """Test that CollectionNotFoundError is converted to NoDataError.""" - mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") - - retriever = CompletionRetriever() - - with patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="No data found"): - await retriever.get_context("test query") - - -@pytest.mark.asyncio -async def test_get_context_empty_results(mock_vector_engine): - """Test that empty string is returned when no chunks are found.""" - mock_vector_engine.search.return_value = [] - - retriever = CompletionRetriever() - - with patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "" - - -@pytest.mark.asyncio -async def test_get_context_top_k_limit(mock_vector_engine): - """Test that top_k parameter limits the number of results.""" - mock_results = [MagicMock() for _ in range(2)] - for i, result in enumerate(mock_results): - result.payload = {"text": f"Chunk {i}"} - - mock_vector_engine.search.return_value = mock_results - - retriever = CompletionRetriever(top_k=2) - - with patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "Chunk 0\nChunk 1" - mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) - - -@pytest.mark.asyncio -async def test_get_context_single_chunk(mock_vector_engine): - """Test get_context with single chunk result.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Single chunk text"} - mock_vector_engine.search.return_value = [mock_result] - - retriever = CompletionRetriever() - - with patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "Single chunk text" - - -@pytest.mark.asyncio -async def test_get_completion_without_session(mock_vector_engine): - """Test get_completion without session caching.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Chunk text"} - mock_vector_engine.search.return_value = [mock_result] - - retriever = CompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_provided_context(mock_vector_engine): - """Test get_completion with provided context.""" - retriever = CompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context="Provided context") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_session(mock_vector_engine): - """Test get_completion with session caching enabled.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Chunk text"} - mock_vector_engine.search.return_value = [mock_result] - - retriever = CompletionRetriever() - - mock_user = MagicMock() - mock_user.id = "test-user-id" - - with ( - patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.completion_retriever.get_conversation_history", - return_value="Previous conversation", - ), - patch( - "cognee.modules.retrieval.completion_retriever.summarize_text", - return_value="Context summary", - ), - patch( - "cognee.modules.retrieval.completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.completion_retriever.save_conversation_history", - ) as mock_save, - patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, - patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = mock_user - - completion = await retriever.get_completion("test query", session_id="test_session") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - mock_save.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_completion_with_session_no_user_id(mock_vector_engine): - """Test get_completion with session config but no user ID.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Chunk text"} - mock_vector_engine.search.return_value = [mock_result] - - retriever = CompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.completion_retriever.generate_completion", - return_value="Generated answer", - ), - patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, - patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = None # No user - - completion = await retriever.get_completion("test query") - - assert isinstance(completion, list) - assert len(completion) == 1 - - -@pytest.mark.asyncio -async def test_get_completion_with_response_model(mock_vector_engine): - """Test get_completion with custom response model.""" - from pydantic import BaseModel - - class TestModel(BaseModel): - answer: str - - mock_result = MagicMock() - mock_result.payload = {"text": "Chunk text"} - mock_vector_engine.search.return_value = [mock_result] - - retriever = CompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.completion_retriever.generate_completion", - return_value=TestModel(answer="Test answer"), - ), - patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", response_model=TestModel) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert isinstance(completion[0], TestModel) - - -@pytest.mark.asyncio -async def test_init_defaults(): - """Test CompletionRetriever initialization with defaults.""" - retriever = CompletionRetriever() - - assert retriever.user_prompt_path == "context_for_question.txt" - assert retriever.system_prompt_path == "answer_simple_question.txt" - assert retriever.top_k == 1 - assert retriever.system_prompt is None - - -@pytest.mark.asyncio -async def test_init_custom_params(): - """Test CompletionRetriever initialization with custom parameters.""" - retriever = CompletionRetriever( - user_prompt_path="custom_user.txt", - system_prompt_path="custom_system.txt", - system_prompt="Custom prompt", - top_k=10, - ) - - assert retriever.user_prompt_path == "custom_user.txt" - assert retriever.system_prompt_path == "custom_system.txt" - assert retriever.system_prompt == "Custom prompt" - assert retriever.top_k == 10 - - -@pytest.mark.asyncio -async def test_get_context_missing_text_key(mock_vector_engine): - """Test get_context handles missing text key in payload.""" - mock_result = MagicMock() - mock_result.payload = {"other_key": "value"} - - mock_vector_engine.search.return_value = [mock_result] - - retriever = CompletionRetriever() - - with patch( - "cognee.modules.retrieval.completion_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(KeyError): - await retriever.get_context("test query") diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py deleted file mode 100644 index e552ac74a..000000000 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ /dev/null @@ -1,193 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock - -from cognee.modules.retrieval.summaries_retriever import SummariesRetriever -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError - - -@pytest.fixture -def mock_vector_engine(): - """Create a mock vector engine.""" - engine = AsyncMock() - engine.search = AsyncMock() - return engine - - -@pytest.mark.asyncio -async def test_get_context_success(mock_vector_engine): - """Test successful retrieval of summary context.""" - mock_result1 = MagicMock() - mock_result1.payload = {"text": "S.R.", "made_from": "chunk1"} - mock_result2 = MagicMock() - mock_result2.payload = {"text": "M.B.", "made_from": "chunk2"} - - mock_vector_engine.search.return_value = [mock_result1, mock_result2] - - retriever = SummariesRetriever(top_k=5) - - with patch( - "cognee.modules.retrieval.summaries_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert len(context) == 2 - assert context[0]["text"] == "S.R." - assert context[1]["text"] == "M.B." - mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5) - - -@pytest.mark.asyncio -async def test_get_context_collection_not_found_error(mock_vector_engine): - """Test that CollectionNotFoundError is converted to NoDataError.""" - mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") - - retriever = SummariesRetriever() - - with patch( - "cognee.modules.retrieval.summaries_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="No data found"): - await retriever.get_context("test query") - - -@pytest.mark.asyncio -async def test_get_context_empty_results(mock_vector_engine): - """Test that empty list is returned when no summaries are found.""" - mock_vector_engine.search.return_value = [] - - retriever = SummariesRetriever() - - with patch( - "cognee.modules.retrieval.summaries_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == [] - - -@pytest.mark.asyncio -async def test_get_context_top_k_limit(mock_vector_engine): - """Test that top_k parameter limits the number of results.""" - mock_results = [MagicMock() for _ in range(3)] - for i, result in enumerate(mock_results): - result.payload = {"text": f"Summary {i}"} - - mock_vector_engine.search.return_value = mock_results - - retriever = SummariesRetriever(top_k=3) - - with patch( - "cognee.modules.retrieval.summaries_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert len(context) == 3 - mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3) - - -@pytest.mark.asyncio -async def test_get_completion_with_context(mock_vector_engine): - """Test get_completion returns provided context.""" - retriever = SummariesRetriever() - - provided_context = [{"text": "S.R."}, {"text": "M.B."}] - completion = await retriever.get_completion("test query", context=provided_context) - - assert completion == provided_context - - -@pytest.mark.asyncio -async def test_get_completion_without_context(mock_vector_engine): - """Test get_completion retrieves context when not provided.""" - mock_result = MagicMock() - mock_result.payload = {"text": "S.R."} - mock_vector_engine.search.return_value = [mock_result] - - retriever = SummariesRetriever() - - with patch( - "cognee.modules.retrieval.summaries_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - completion = await retriever.get_completion("test query") - - assert len(completion) == 1 - assert completion[0]["text"] == "S.R." - - -@pytest.mark.asyncio -async def test_init_defaults(): - """Test SummariesRetriever initialization with defaults.""" - retriever = SummariesRetriever() - - assert retriever.top_k == 5 - - -@pytest.mark.asyncio -async def test_init_custom_top_k(): - """Test SummariesRetriever initialization with custom top_k.""" - retriever = SummariesRetriever(top_k=10) - - assert retriever.top_k == 10 - - -@pytest.mark.asyncio -async def test_get_context_empty_payload(mock_vector_engine): - """Test get_context handles empty payload.""" - mock_result = MagicMock() - mock_result.payload = {} - - mock_vector_engine.search.return_value = [mock_result] - - retriever = SummariesRetriever() - - with patch( - "cognee.modules.retrieval.summaries_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert len(context) == 1 - assert context[0] == {} - - -@pytest.mark.asyncio -async def test_get_completion_with_session_id(mock_vector_engine): - """Test get_completion with session_id parameter.""" - mock_result = MagicMock() - mock_result.payload = {"text": "S.R."} - mock_vector_engine.search.return_value = [mock_result] - - retriever = SummariesRetriever() - - with patch( - "cognee.modules.retrieval.summaries_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - completion = await retriever.get_completion("test query", session_id="test_session") - - assert len(completion) == 1 - assert completion[0]["text"] == "S.R." - - -@pytest.mark.asyncio -async def test_get_completion_with_kwargs(mock_vector_engine): - """Test get_completion accepts additional kwargs.""" - mock_result = MagicMock() - mock_result.payload = {"text": "S.R."} - mock_vector_engine.search.return_value = [mock_result] - - retriever = SummariesRetriever() - - with patch( - "cognee.modules.retrieval.summaries_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - completion = await retriever.get_completion("test query", extra_param="value") - - assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py deleted file mode 100644 index 1d2f4c84d..000000000 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ /dev/null @@ -1,705 +0,0 @@ -from types import SimpleNamespace -import pytest -import os -from unittest.mock import AsyncMock, patch, MagicMock -from datetime import datetime - -from cognee.modules.retrieval.temporal_retriever import TemporalRetriever -from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp -from cognee.infrastructure.llm import LLMGateway - - -# Test TemporalRetriever initialization defaults and overrides -def test_init_defaults_and_overrides(): - tr = TemporalRetriever() - assert tr.top_k == 5 - assert tr.user_prompt_path == "graph_context_for_question.txt" - assert tr.system_prompt_path == "answer_simple_question.txt" - assert tr.time_extraction_prompt_path == "extract_query_time.txt" - - tr2 = TemporalRetriever( - top_k=3, - user_prompt_path="u.txt", - system_prompt_path="s.txt", - time_extraction_prompt_path="t.txt", - ) - assert tr2.top_k == 3 - assert tr2.user_prompt_path == "u.txt" - assert tr2.system_prompt_path == "s.txt" - assert tr2.time_extraction_prompt_path == "t.txt" - - -# Test descriptions_to_string with basic and empty results -def test_descriptions_to_string_basic_and_empty(): - tr = TemporalRetriever() - - results = [ - {"description": " First "}, - {"nope": "no description"}, - {"description": "Second"}, - {"description": ""}, - {"description": " Third line "}, - ] - - s = tr.descriptions_to_string(results) - assert s == "First\n#####################\nSecond\n#####################\nThird line" - - assert tr.descriptions_to_string([]) == "" - - -# Test filter_top_k_events sorts and limits correctly -@pytest.mark.asyncio -async def test_filter_top_k_events_sorts_and_limits(): - tr = TemporalRetriever(top_k=2) - - relevant_events = [ - { - "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3 - not in vector results"}, - ] - } - ] - - scored_results = [ - SimpleNamespace(payload={"id": "e2"}, score=0.10), - SimpleNamespace(payload={"id": "e1"}, score=0.20), - ] - - top = await tr.filter_top_k_events(relevant_events, scored_results) - - assert [e["id"] for e in top] == ["e2", "e1"] - assert all("score" in e for e in top) - assert top[0]["score"] == 0.10 - assert top[1]["score"] == 0.20 - - -# Test filter_top_k_events handles unknown ids as infinite scores -@pytest.mark.asyncio -async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k(): - tr = TemporalRetriever(top_k=2) - - relevant_events = [ - { - "events": [ - {"id": "known1", "description": "Known 1"}, - {"id": "unknown", "description": "Unknown"}, - {"id": "known2", "description": "Known 2"}, - ] - } - ] - - scored_results = [ - SimpleNamespace(payload={"id": "known2"}, score=0.05), - SimpleNamespace(payload={"id": "known1"}, score=0.50), - ] - - top = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in top] == ["known2", "known1"] - assert all(e["score"] != float("inf") for e in top) - - -# Test descriptions_to_string with unicode and newlines -def test_descriptions_to_string_unicode_and_newlines(): - tr = TemporalRetriever() - results = [ - {"description": "Line A\nwith newline"}, - {"description": "This is a description"}, - ] - s = tr.descriptions_to_string(results) - assert "Line A\nwith newline" in s - assert "This is a description" in s - assert s.count("#####################") == 1 - - -# Test filter_top_k_events when top_k is larger than available events -@pytest.mark.asyncio -async def test_filter_top_k_events_limits_when_top_k_exceeds_events(): - tr = TemporalRetriever(top_k=10) - relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}] - scored_results = [ - SimpleNamespace(payload={"id": "a"}, score=0.1), - SimpleNamespace(payload={"id": "b"}, score=0.2), - ] - out = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in out] == ["a", "b"] - - -# Test filter_top_k_events when scored_results is empty -@pytest.mark.asyncio -async def test_filter_top_k_events_handles_empty_scored_results(): - tr = TemporalRetriever(top_k=2) - relevant_events = [{"events": [{"id": "x"}, {"id": "y"}]}] - scored_results = [] - out = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in out] == ["x", "y"] - assert all(e["score"] == float("inf") for e in out) - - -# Test filter_top_k_events error handling for missing structure -@pytest.mark.asyncio -async def test_filter_top_k_events_error_handling(): - tr = TemporalRetriever(top_k=2) - with pytest.raises((KeyError, TypeError)): - await tr.filter_top_k_events([{}], []) - - -@pytest.fixture -def mock_graph_engine(): - """Create a mock graph engine.""" - engine = AsyncMock() - engine.collect_time_ids = AsyncMock() - engine.collect_events = AsyncMock() - return engine - - -@pytest.fixture -def mock_vector_engine(): - """Create a mock vector engine.""" - engine = AsyncMock() - engine.embedding_engine = AsyncMock() - engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - engine.search = AsyncMock() - return engine - - -@pytest.mark.asyncio -async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine): - """Test get_context when time range is extracted from query.""" - retriever = TemporalRetriever(top_k=5) - - mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"] - mock_graph_engine.collect_events.return_value = [ - { - "events": [ - {"id": "e1", "description": "Event 1"}, - {"id": "e2", "description": "Event 2"}, - ] - } - ] - - mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05) - mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10) - mock_vector_engine.search.return_value = [mock_result1, mock_result2] - - with ( - patch.object( - retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - ): - context = await retriever.get_context("What happened in 2024?") - - assert isinstance(context, str) - assert len(context) > 0 - assert "Event" in context - - -@pytest.mark.asyncio -async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine): - """Test get_context falls back to triplets when no time is extracted.""" - retriever = TemporalRetriever() - - with ( - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object( - retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] - ) as mock_get_triplets, - patch.object( - retriever, "resolve_edges_to_text", return_value="triplet text" - ) as mock_resolve, - ): - - async def mock_extract_time(query): - return None, None - - retriever.extract_time_from_query = mock_extract_time - - context = await retriever.get_context("test query") - - assert context == "triplet text" - mock_get_triplets.assert_awaited_once_with("test query") - mock_resolve.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_context_no_events_found(mock_graph_engine): - """Test get_context falls back to triplets when no events are found.""" - retriever = TemporalRetriever() - - mock_graph_engine.collect_time_ids.return_value = [] - - with ( - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch.object( - retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] - ) as mock_get_triplets, - patch.object( - retriever, "resolve_edges_to_text", return_value="triplet text" - ) as mock_resolve, - ): - - async def mock_extract_time(query): - return "2024-01-01", "2024-12-31" - - retriever.extract_time_from_query = mock_extract_time - - context = await retriever.get_context("test query") - - assert context == "triplet text" - mock_get_triplets.assert_awaited_once_with("test query") - mock_resolve.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine): - """Test get_context with only time_from.""" - retriever = TemporalRetriever(top_k=5) - - mock_graph_engine.collect_time_ids.return_value = ["e1"] - mock_graph_engine.collect_events.return_value = [ - { - "events": [ - {"id": "e1", "description": "Event 1"}, - ] - } - ] - - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) - mock_vector_engine.search.return_value = [mock_result] - - with ( - patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)), - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - ): - context = await retriever.get_context("What happened after 2024?") - - assert isinstance(context, str) - assert "Event 1" in context - - -@pytest.mark.asyncio -async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine): - """Test get_context with only time_to.""" - retriever = TemporalRetriever(top_k=5) - - mock_graph_engine.collect_time_ids.return_value = ["e1"] - mock_graph_engine.collect_events.return_value = [ - { - "events": [ - {"id": "e1", "description": "Event 1"}, - ] - } - ] - - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) - mock_vector_engine.search.return_value = [mock_result] - - with ( - patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")), - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - ): - context = await retriever.get_context("What happened before 2024?") - - assert isinstance(context, str) - assert "Event 1" in context - - -@pytest.mark.asyncio -async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine): - """Test get_completion retrieves context when not provided.""" - retriever = TemporalRetriever() - - mock_graph_engine.collect_time_ids.return_value = ["e1"] - mock_graph_engine.collect_events.return_value = [ - { - "events": [ - {"id": "e1", "description": "Event 1"}, - ] - } - ] - - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) - mock_vector_engine.search.return_value = [mock_result] - - with ( - patch.object( - retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.generate_completion", - return_value="Generated answer", - ), - patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("What happened in 2024?") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_provided_context(): - """Test get_completion uses provided context.""" - retriever = TemporalRetriever() - - with ( - patch( - "cognee.modules.retrieval.temporal_retriever.generate_completion", - return_value="Generated answer", - ), - patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context="Provided context") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine): - """Test get_completion with session caching enabled.""" - retriever = TemporalRetriever() - - mock_graph_engine.collect_time_ids.return_value = ["e1"] - mock_graph_engine.collect_events.return_value = [ - { - "events": [ - {"id": "e1", "description": "Event 1"}, - ] - } - ] - - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) - mock_vector_engine.search.return_value = [mock_result] - - mock_user = MagicMock() - mock_user.id = "test-user-id" - - with ( - patch.object( - retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_conversation_history", - return_value="Previous conversation", - ), - patch( - "cognee.modules.retrieval.temporal_retriever.summarize_text", - return_value="Context summary", - ), - patch( - "cognee.modules.retrieval.temporal_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.temporal_retriever.save_conversation_history", - ) as mock_save, - patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, - patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = mock_user - - completion = await retriever.get_completion( - "What happened in 2024?", session_id="test_session" - ) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - mock_save.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine): - """Test get_completion with session config but no user ID.""" - retriever = TemporalRetriever() - - mock_graph_engine.collect_time_ids.return_value = ["e1"] - mock_graph_engine.collect_events.return_value = [ - { - "events": [ - {"id": "e1", "description": "Event 1"}, - ] - } - ] - - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) - mock_vector_engine.search.return_value = [mock_result] - - with ( - patch.object( - retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.generate_completion", - return_value="Generated answer", - ), - patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, - patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = None # No user - - completion = await retriever.get_completion("What happened in 2024?") - - assert isinstance(completion, list) - assert len(completion) == 1 - - -@pytest.mark.asyncio -async def test_get_completion_context_retrieved_but_empty(mock_graph_engine): - """Test get_completion when get_context returns empty string.""" - retriever = TemporalRetriever() - - with ( - patch.object( - retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_vector_engine", - ) as mock_get_vector, - patch.object(retriever, "filter_top_k_events", return_value=[]), - ): - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - mock_get_vector.return_value = mock_vector_engine - - mock_graph_engine.collect_time_ids.return_value = ["e1"] - mock_graph_engine.collect_events.return_value = [ - { - "events": [ - {"id": "e1", "description": ""}, - ] - } - ] - - with pytest.raises((UnboundLocalError, NameError)): - await retriever.get_completion("test query") - - -@pytest.mark.asyncio -async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine): - """Test get_completion with custom response model.""" - from pydantic import BaseModel - - class TestModel(BaseModel): - answer: str - - retriever = TemporalRetriever() - - mock_graph_engine.collect_time_ids.return_value = ["e1"] - mock_graph_engine.collect_events.return_value = [ - { - "events": [ - {"id": "e1", "description": "Event 1"}, - ] - } - ] - - mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) - mock_vector_engine.search.return_value = [mock_result] - - with ( - patch.object( - retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.temporal_retriever.generate_completion", - return_value=TestModel(answer="Test answer"), - ), - patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion( - "What happened in 2024?", response_model=TestModel - ) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert isinstance(completion[0], TestModel) - - -@pytest.mark.asyncio -async def test_extract_time_from_query_relative_path(): - """Test extract_time_from_query with relative prompt path.""" - retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") - - mock_timestamp_from = Timestamp(year=2024, month=1, day=1) - mock_timestamp_to = Timestamp(year=2024, month=12, day=31) - mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) - - with ( - patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), - patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, - patch( - "cognee.modules.retrieval.temporal_retriever.render_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_interval, - ), - ): - mock_datetime.now.return_value.strftime.return_value = "11-12-2024" - - time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") - - assert time_from == mock_timestamp_from - assert time_to == mock_timestamp_to - - -@pytest.mark.asyncio -async def test_extract_time_from_query_absolute_path(): - """Test extract_time_from_query with absolute prompt path.""" - retriever = TemporalRetriever( - time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt" - ) - - mock_timestamp_from = Timestamp(year=2024, month=1, day=1) - mock_timestamp_to = Timestamp(year=2024, month=12, day=31) - mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) - - with ( - patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True), - patch( - "cognee.modules.retrieval.temporal_retriever.os.path.dirname", - return_value="/absolute/path/to", - ), - patch( - "cognee.modules.retrieval.temporal_retriever.os.path.basename", - return_value="extract_query_time.txt", - ), - patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, - patch( - "cognee.modules.retrieval.temporal_retriever.render_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_interval, - ), - ): - mock_datetime.now.return_value.strftime.return_value = "11-12-2024" - - time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") - - assert time_from == mock_timestamp_from - assert time_to == mock_timestamp_to - - -@pytest.mark.asyncio -async def test_extract_time_from_query_with_none_values(): - """Test extract_time_from_query when interval has None values.""" - retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") - - mock_interval = QueryInterval(starts_at=None, ends_at=None) - - with ( - patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), - patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, - patch( - "cognee.modules.retrieval.temporal_retriever.render_prompt", - return_value="System prompt", - ), - patch.object( - LLMGateway, - "acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_interval, - ), - ): - mock_datetime.now.return_value.strftime.return_value = "11-12-2024" - - time_from, time_to = await retriever.extract_time_from_query("What happened?") - - assert time_from is None - assert time_to is None diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py deleted file mode 100644 index b7cbe08d7..000000000 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ /dev/null @@ -1,817 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock - -from cognee.modules.retrieval.utils.brute_force_triplet_search import ( - brute_force_triplet_search, - get_memory_fragment, - format_triplets, -) -from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph -from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError -from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError - - -class MockScoredResult: - """Mock class for vector search results.""" - - def __init__(self, id, score, payload=None): - self.id = id - self.score = score - self.payload = payload or {} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_empty_query(): - """Test that empty query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): - await brute_force_triplet_search(query="") - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_none_query(): - """Test that None query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): - await brute_force_triplet_search(query=None) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_negative_top_k(): - """Test that negative top_k raises ValueError.""" - with pytest.raises(ValueError, match="top_k must be a positive integer."): - await brute_force_triplet_search(query="test query", top_k=-1) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_zero_top_k(): - """Test that zero top_k raises ValueError.""" - with pytest.raises(ValueError, match="top_k must be a positive integer."): - await brute_force_triplet_search(query="test query", top_k=0) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_limit_global_search(): - """Test that wide_search_limit is applied for global search (node_name=None).""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search( - query="test", - node_name=None, # Global search - wide_search_top_k=75, - ) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] == 75 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): - """Test that wide_search_limit is None for filtered search (node_name provided).""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search( - query="test", - node_name=["Node1"], - wide_search_top_k=50, - ) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] is None - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_default(): - """Test that wide_search_top_k defaults to 100.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", node_name=None) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] == 100 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_default_collections(): - """Test that default collections are used when none provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test") - - expected_collections = [ - "Entity_name", - "TextSummary_text", - "EntityType_name", - "DocumentChunk_text", - "EdgeType_relationship_name", - ] - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert call_collections == expected_collections - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_custom_collections(): - """Test that custom collections are used when provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - custom_collections = ["CustomCol1", "CustomCol2"] - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", collections=custom_collections) - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_always_includes_edge_collection(): - """Test that EdgeType_relationship_name is always searched even when not in collections.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - collections_without_edge = ["Entity_name", "TextSummary_text"] - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", collections=collections_without_edge) - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert "EdgeType_relationship_name" in call_collections - assert set(call_collections) == set(collections_without_edge) | { - "EdgeType_relationship_name" - } - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_all_collections_empty(): - """Test that empty list is returned when all collections return no results.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - results = await brute_force_triplet_search(query="test") - assert results == [] - - -# Tests for query embedding - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_embeds_query(): - """Test that query is embedded before searching.""" - query_text = "test query" - expected_vector = [0.1, 0.2, 0.3] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query=query_text) - - mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text]) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["query_vector"] == expected_vector - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_extracts_node_ids_global_search(): - """Test that node IDs are extracted from search results for global search.""" - scored_results = [ - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - MockScoredResult("node3", 0.92), - ] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=scored_results) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_reuses_provided_fragment(): - """Test that provided memory fragment is reused instead of creating new one.""" - provided_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment" - ) as mock_get_fragment, - ): - await brute_force_triplet_search( - query="test", - memory_fragment=provided_fragment, - node_name=["node"], - ) - - mock_get_fragment.assert_not_called() - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): - """Test that memory fragment is created when not provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment, - ): - await brute_force_triplet_search(query="test", node_name=["node"]) - - mock_get_fragment.assert_called_once() - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation(): - """Test that custom top_k is passed to importance calculation.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ), - ): - custom_top_k = 15 - await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) - - mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) - - -@pytest.mark.asyncio -async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): - """Test that get_memory_fragment returns empty graph when entity not found (line 85).""" - mock_graph_engine = AsyncMock() - - # Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called - mock_fragment = MagicMock(spec=CogneeGraph) - mock_fragment.project_graph_from_db = AsyncMock( - side_effect=EntityNotFoundError("Entity not found") - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph", - return_value=mock_fragment, - ), - ): - result = await get_memory_fragment() - - # Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85) - assert result == mock_fragment - mock_fragment.project_graph_from_db.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_memory_fragment_returns_empty_graph_on_error(): - """Test that get_memory_fragment returns empty graph on generic error.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error")) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, - ): - fragment = await get_memory_fragment() - - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_deduplicates_node_ids(): - """Test that duplicate node IDs across collections are deduplicated.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [ - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - ] - elif collection_name == "TextSummary_text": - return [ - MockScoredResult("node1", 0.90), - MockScoredResult("node3", 0.92), - ] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} - assert len(call_kwargs["relevant_ids_to_filter"]) == 3 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_excludes_edge_collection(): - """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [MockScoredResult("node1", 0.95)] - elif collection_name == "EdgeType_relationship_name": - return [MockScoredResult("edge1", 0.88)] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search( - query="test", - node_name=None, - collections=["Entity_name", "EdgeType_relationship_name"], - ) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert call_kwargs["relevant_ids_to_filter"] == ["node1"] - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_skips_nodes_without_ids(): - """Test that nodes without ID attribute are skipped.""" - - class ScoredResultNoId: - """Mock result without id attribute.""" - - def __init__(self, score): - self.score = score - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [ - MockScoredResult("node1", 0.95), - ScoredResultNoId(0.90), - MockScoredResult("node2", 0.87), - ] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_handles_tuple_results(): - """Test that both list and tuple results are handled correctly.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return ( - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - ) - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_mixed_empty_collections(): - """Test ID extraction with mixed empty and non-empty collections.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [MockScoredResult("node1", 0.95)] - elif collection_name == "TextSummary_text": - return [] - elif collection_name == "EntityType_name": - return [MockScoredResult("node2", 0.92)] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} - - -def test_format_triplets(): - """Test format_triplets function.""" - mock_edge = MagicMock() - mock_node1 = MagicMock() - mock_node2 = MagicMock() - - mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"} - mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"} - mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"} - - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - result = format_triplets([mock_edge]) - - assert isinstance(result, str) - assert "Node1" in result - assert "Node2" in result - assert "relates_to" in result - assert "connects" in result - - -def test_format_triplets_with_none_values(): - """Test format_triplets filters out None values.""" - mock_edge = MagicMock() - mock_node1 = MagicMock() - mock_node2 = MagicMock() - - mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"} - mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None} - mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None} - - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - result = format_triplets([mock_edge]) - - assert "Node1" in result - assert "Node2" in result - assert "relates_to" in result - assert "None" not in result or result.count("None") == 0 - - -def test_format_triplets_with_nested_dict(): - """Test format_triplets handles nested dict attributes (lines 23-35).""" - mock_edge = MagicMock() - mock_node1 = MagicMock() - mock_node2 = MagicMock() - - mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}} - mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}} - mock_edge.attributes = {"relationship_name": "relates_to"} - - mock_edge.node1 = mock_node1 - mock_edge.node2 = mock_node2 - - result = format_triplets([mock_edge]) - - assert isinstance(result, str) - assert "Node1" in result - assert "Node2" in result - assert "relates_to" in result - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_vector_engine_init_error(): - """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147).""" - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine" - ) as mock_get_vector_engine, - ): - mock_get_vector_engine.side_effect = Exception("Initialization error") - - with pytest.raises(RuntimeError, match="Initialization error"): - await brute_force_triplet_search(query="test query") - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_collection_not_found_error(): - """Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157).""" - mock_vector_engine = AsyncMock() - mock_embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine = mock_embedding_engine - mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - - mock_vector_engine.search = AsyncMock( - side_effect=[ - CollectionNotFoundError("Collection not found"), - [], - [], - ] - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=CogneeGraph(), - ), - ): - result = await brute_force_triplet_search( - query="test query", collections=["missing_collection", "existing_collection"] - ) - - assert result == [] - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_generic_exception(): - """Test brute_force_triplet_search handles generic exceptions (lines 209-217).""" - mock_vector_engine = AsyncMock() - mock_embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine = mock_embedding_engine - mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - - mock_vector_engine.search = AsyncMock(side_effect=Exception("Generic error")) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - ): - with pytest.raises(Exception, match="Generic error"): - await brute_force_triplet_search(query="test query") - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none(): - """Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191).""" - mock_vector_engine = AsyncMock() - mock_embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine = mock_embedding_engine - mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - - mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) - mock_vector_engine.search = AsyncMock(return_value=[mock_result]) - - mock_fragment = AsyncMock() - mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() - mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() - mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[]) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment, - ): - await brute_force_triplet_search(query="test query", node_name=["Node1"]) - - assert mock_get_fragment.called - call_kwargs = mock_get_fragment.call_args.kwargs if mock_get_fragment.call_args else {} - assert call_kwargs.get("relevant_ids_to_filter") is None - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_collection_not_found_at_top_level(): - """Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210).""" - mock_vector_engine = AsyncMock() - mock_embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine = mock_embedding_engine - mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - - mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) - mock_vector_engine.search = AsyncMock(return_value=[mock_result]) - - mock_fragment = AsyncMock() - mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() - mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() - mock_fragment.calculate_top_triplet_importances = AsyncMock( - side_effect=CollectionNotFoundError("Collection not found") - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ), - ): - result = await brute_force_triplet_search(query="test query") - - assert result == [] diff --git a/cognee/tests/unit/modules/retrieval/test_completion.py b/cognee/tests/unit/modules/retrieval/test_completion.py deleted file mode 100644 index 9a836c2cc..000000000 --- a/cognee/tests/unit/modules/retrieval/test_completion.py +++ /dev/null @@ -1,343 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from typing import Type - - -class TestGenerateCompletion: - @pytest.mark.asyncio - async def test_generate_completion_with_system_prompt(self): - """Test generate_completion with provided system_prompt.""" - mock_llm_response = "Generated answer" - - with ( - patch( - "cognee.modules.retrieval.utils.completion.render_prompt", - return_value="User prompt text", - ), - patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ) as mock_llm, - ): - from cognee.modules.retrieval.utils.completion import generate_completion - - result = await generate_completion( - query="What is AI?", - context="AI is artificial intelligence", - user_prompt_path="user_prompt.txt", - system_prompt_path="system_prompt.txt", - system_prompt="Custom system prompt", - ) - - assert result == mock_llm_response - mock_llm.assert_awaited_once_with( - text_input="User prompt text", - system_prompt="Custom system prompt", - response_model=str, - ) - - @pytest.mark.asyncio - async def test_generate_completion_without_system_prompt(self): - """Test generate_completion reads system_prompt from file when not provided.""" - mock_llm_response = "Generated answer" - - with ( - patch( - "cognee.modules.retrieval.utils.completion.render_prompt", - return_value="User prompt text", - ), - patch( - "cognee.modules.retrieval.utils.completion.read_query_prompt", - return_value="System prompt from file", - ), - patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ) as mock_llm, - ): - from cognee.modules.retrieval.utils.completion import generate_completion - - result = await generate_completion( - query="What is AI?", - context="AI is artificial intelligence", - user_prompt_path="user_prompt.txt", - system_prompt_path="system_prompt.txt", - ) - - assert result == mock_llm_response - mock_llm.assert_awaited_once_with( - text_input="User prompt text", - system_prompt="System prompt from file", - response_model=str, - ) - - @pytest.mark.asyncio - async def test_generate_completion_with_conversation_history(self): - """Test generate_completion includes conversation_history in system_prompt.""" - mock_llm_response = "Generated answer" - - with ( - patch( - "cognee.modules.retrieval.utils.completion.render_prompt", - return_value="User prompt text", - ), - patch( - "cognee.modules.retrieval.utils.completion.read_query_prompt", - return_value="System prompt from file", - ), - patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ) as mock_llm, - ): - from cognee.modules.retrieval.utils.completion import generate_completion - - result = await generate_completion( - query="What is AI?", - context="AI is artificial intelligence", - user_prompt_path="user_prompt.txt", - system_prompt_path="system_prompt.txt", - conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning", - ) - - assert result == mock_llm_response - expected_system_prompt = ( - "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" - + "\nTASK:" - + "System prompt from file" - ) - mock_llm.assert_awaited_once_with( - text_input="User prompt text", - system_prompt=expected_system_prompt, - response_model=str, - ) - - @pytest.mark.asyncio - async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self): - """Test generate_completion includes conversation_history with custom system_prompt.""" - mock_llm_response = "Generated answer" - - with ( - patch( - "cognee.modules.retrieval.utils.completion.render_prompt", - return_value="User prompt text", - ), - patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ) as mock_llm, - ): - from cognee.modules.retrieval.utils.completion import generate_completion - - result = await generate_completion( - query="What is AI?", - context="AI is artificial intelligence", - user_prompt_path="user_prompt.txt", - system_prompt_path="system_prompt.txt", - system_prompt="Custom system prompt", - conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning", - ) - - assert result == mock_llm_response - expected_system_prompt = ( - "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" - + "\nTASK:" - + "Custom system prompt" - ) - mock_llm.assert_awaited_once_with( - text_input="User prompt text", - system_prompt=expected_system_prompt, - response_model=str, - ) - - @pytest.mark.asyncio - async def test_generate_completion_with_response_model(self): - """Test generate_completion with custom response_model.""" - mock_response_model = MagicMock() - mock_llm_response = {"answer": "Generated answer"} - - with ( - patch( - "cognee.modules.retrieval.utils.completion.render_prompt", - return_value="User prompt text", - ), - patch( - "cognee.modules.retrieval.utils.completion.read_query_prompt", - return_value="System prompt from file", - ), - patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ) as mock_llm, - ): - from cognee.modules.retrieval.utils.completion import generate_completion - - result = await generate_completion( - query="What is AI?", - context="AI is artificial intelligence", - user_prompt_path="user_prompt.txt", - system_prompt_path="system_prompt.txt", - response_model=mock_response_model, - ) - - assert result == mock_llm_response - mock_llm.assert_awaited_once_with( - text_input="User prompt text", - system_prompt="System prompt from file", - response_model=mock_response_model, - ) - - @pytest.mark.asyncio - async def test_generate_completion_render_prompt_args(self): - """Test generate_completion passes correct args to render_prompt.""" - mock_llm_response = "Generated answer" - - with ( - patch( - "cognee.modules.retrieval.utils.completion.render_prompt", - return_value="User prompt text", - ) as mock_render, - patch( - "cognee.modules.retrieval.utils.completion.read_query_prompt", - return_value="System prompt from file", - ), - patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ), - ): - from cognee.modules.retrieval.utils.completion import generate_completion - - await generate_completion( - query="What is AI?", - context="AI is artificial intelligence", - user_prompt_path="user_prompt.txt", - system_prompt_path="system_prompt.txt", - ) - - mock_render.assert_called_once_with( - "user_prompt.txt", - {"question": "What is AI?", "context": "AI is artificial intelligence"}, - ) - - -class TestSummarizeText: - @pytest.mark.asyncio - async def test_summarize_text_with_system_prompt(self): - """Test summarize_text with provided system_prompt.""" - mock_llm_response = "Summary text" - - with patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ) as mock_llm: - from cognee.modules.retrieval.utils.completion import summarize_text - - result = await summarize_text( - text="Long text to summarize", - system_prompt_path="summarize_search_results.txt", - system_prompt="Custom summary prompt", - ) - - assert result == mock_llm_response - mock_llm.assert_awaited_once_with( - text_input="Long text to summarize", - system_prompt="Custom summary prompt", - response_model=str, - ) - - @pytest.mark.asyncio - async def test_summarize_text_without_system_prompt(self): - """Test summarize_text reads system_prompt from file when not provided.""" - mock_llm_response = "Summary text" - - with ( - patch( - "cognee.modules.retrieval.utils.completion.read_query_prompt", - return_value="System prompt from file", - ), - patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ) as mock_llm, - ): - from cognee.modules.retrieval.utils.completion import summarize_text - - result = await summarize_text( - text="Long text to summarize", - system_prompt_path="summarize_search_results.txt", - ) - - assert result == mock_llm_response - mock_llm.assert_awaited_once_with( - text_input="Long text to summarize", - system_prompt="System prompt from file", - response_model=str, - ) - - @pytest.mark.asyncio - async def test_summarize_text_default_prompt_path(self): - """Test summarize_text uses default prompt path when not provided.""" - mock_llm_response = "Summary text" - - with ( - patch( - "cognee.modules.retrieval.utils.completion.read_query_prompt", - return_value="Default system prompt", - ) as mock_read, - patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ) as mock_llm, - ): - from cognee.modules.retrieval.utils.completion import summarize_text - - result = await summarize_text(text="Long text to summarize") - - assert result == mock_llm_response - mock_read.assert_called_once_with("summarize_search_results.txt") - mock_llm.assert_awaited_once_with( - text_input="Long text to summarize", - system_prompt="Default system prompt", - response_model=str, - ) - - @pytest.mark.asyncio - async def test_summarize_text_custom_prompt_path(self): - """Test summarize_text uses custom prompt path when provided.""" - mock_llm_response = "Summary text" - - with ( - patch( - "cognee.modules.retrieval.utils.completion.read_query_prompt", - return_value="Custom system prompt", - ) as mock_read, - patch( - "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_llm_response, - ) as mock_llm, - ): - from cognee.modules.retrieval.utils.completion import summarize_text - - result = await summarize_text( - text="Long text to summarize", - system_prompt_path="custom_prompt.txt", - ) - - assert result == mock_llm_response - mock_read.assert_called_once_with("custom_prompt.txt") - mock_llm.assert_awaited_once_with( - text_input="Long text to summarize", - system_prompt="Custom system prompt", - response_model=str, - ) diff --git a/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py deleted file mode 100644 index 2af10da5e..000000000 --- a/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +++ /dev/null @@ -1,157 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock - -from cognee.modules.retrieval.graph_summary_completion_retriever import ( - GraphSummaryCompletionRetriever, -) -from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge - - -@pytest.fixture -def mock_edge(): - """Create a mock edge.""" - edge = MagicMock(spec=Edge) - return edge - - -class TestGraphSummaryCompletionRetriever: - @pytest.mark.asyncio - async def test_init_defaults(self): - """Test GraphSummaryCompletionRetriever initialization with defaults.""" - retriever = GraphSummaryCompletionRetriever() - - assert retriever.summarize_prompt_path == "summarize_search_results.txt" - assert retriever.user_prompt_path == "graph_context_for_question.txt" - assert retriever.system_prompt_path == "answer_simple_question.txt" - assert retriever.top_k == 5 - assert retriever.save_interaction is False - - @pytest.mark.asyncio - async def test_init_custom_params(self): - """Test GraphSummaryCompletionRetriever initialization with custom parameters.""" - retriever = GraphSummaryCompletionRetriever( - user_prompt_path="custom_user.txt", - system_prompt_path="custom_system.txt", - summarize_prompt_path="custom_summarize.txt", - system_prompt="Custom system prompt", - top_k=10, - save_interaction=True, - wide_search_top_k=200, - triplet_distance_penalty=2.5, - ) - - assert retriever.summarize_prompt_path == "custom_summarize.txt" - assert retriever.user_prompt_path == "custom_user.txt" - assert retriever.system_prompt_path == "custom_system.txt" - assert retriever.top_k == 10 - assert retriever.save_interaction is True - - @pytest.mark.asyncio - async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge): - """Test resolve_edges_to_text calls super method and then summarizes.""" - retriever = GraphSummaryCompletionRetriever( - summarize_prompt_path="custom_summarize.txt", - system_prompt="Custom system prompt", - ) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", - new_callable=AsyncMock, - return_value="Resolved edges text", - ) as mock_super_resolve, - patch( - "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", - new_callable=AsyncMock, - return_value="Summarized text", - ) as mock_summarize, - ): - result = await retriever.resolve_edges_to_text([mock_edge]) - - assert result == "Summarized text" - mock_super_resolve.assert_awaited_once_with([mock_edge]) - mock_summarize.assert_awaited_once_with( - "Resolved edges text", - "custom_summarize.txt", - "Custom system prompt", - ) - - @pytest.mark.asyncio - async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge): - """Test resolve_edges_to_text uses None for system_prompt when not provided.""" - retriever = GraphSummaryCompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", - new_callable=AsyncMock, - return_value="Resolved edges text", - ), - patch( - "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", - new_callable=AsyncMock, - return_value="Summarized text", - ) as mock_summarize, - ): - await retriever.resolve_edges_to_text([mock_edge]) - - mock_summarize.assert_awaited_once_with( - "Resolved edges text", - "summarize_search_results.txt", - None, - ) - - @pytest.mark.asyncio - async def test_resolve_edges_to_text_with_empty_edges(self): - """Test resolve_edges_to_text handles empty edges list.""" - retriever = GraphSummaryCompletionRetriever() - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", - new_callable=AsyncMock, - return_value="", - ), - patch( - "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", - new_callable=AsyncMock, - return_value="Empty summary", - ) as mock_summarize, - ): - result = await retriever.resolve_edges_to_text([]) - - assert result == "Empty summary" - mock_summarize.assert_awaited_once_with( - "", - "summarize_search_results.txt", - None, - ) - - @pytest.mark.asyncio - async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge): - """Test resolve_edges_to_text handles multiple edges.""" - retriever = GraphSummaryCompletionRetriever() - - mock_edge2 = MagicMock(spec=Edge) - mock_edge3 = MagicMock(spec=Edge) - - with ( - patch( - "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", - new_callable=AsyncMock, - return_value="Multiple edges resolved text", - ), - patch( - "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", - new_callable=AsyncMock, - return_value="Multiple edges summarized", - ) as mock_summarize, - ): - result = await retriever.resolve_edges_to_text([mock_edge, mock_edge2, mock_edge3]) - - assert result == "Multiple edges summarized" - mock_summarize.assert_awaited_once_with( - "Multiple edges resolved text", - "summarize_search_results.txt", - None, - ) diff --git a/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py deleted file mode 100644 index a1e746bb9..000000000 --- a/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +++ /dev/null @@ -1,312 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from uuid import UUID, NAMESPACE_OID, uuid5 - -from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback -from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment -from cognee.modules.engine.models import NodeSet - - -@pytest.fixture -def mock_feedback_evaluation(): - """Create a mock feedback evaluation.""" - evaluation = MagicMock(spec=UserFeedbackEvaluation) - evaluation.evaluation = MagicMock() - evaluation.evaluation.value = "positive" - evaluation.score = 4.5 - return evaluation - - -@pytest.fixture -def mock_graph_engine(): - """Create a mock graph engine.""" - engine = AsyncMock() - engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) - engine.add_edges = AsyncMock() - engine.apply_feedback_weight = AsyncMock() - return engine - - -class TestUserQAFeedback: - @pytest.mark.asyncio - async def test_init_default(self): - """Test UserQAFeedback initialization with default last_k.""" - retriever = UserQAFeedback() - assert retriever.last_k == 1 - - @pytest.mark.asyncio - async def test_init_custom_last_k(self): - """Test UserQAFeedback initialization with custom last_k.""" - retriever = UserQAFeedback(last_k=5) - assert retriever.last_k == 5 - - @pytest.mark.asyncio - async def test_add_feedback_success_with_relationships( - self, mock_feedback_evaluation, mock_graph_engine - ): - """Test add_feedback successfully creates feedback with relationships.""" - interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000")) - interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001")) - mock_graph_engine.get_last_user_interaction_ids = AsyncMock( - return_value=[interaction_id_1, interaction_id_2] - ) - - feedback_text = "This answer was helpful" - - with ( - patch( - "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_feedback_evaluation, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.add_data_points", - new_callable=AsyncMock, - ) as mock_add_data, - patch( - "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", - new_callable=AsyncMock, - ) as mock_index_edges, - ): - retriever = UserQAFeedback(last_k=2) - result = await retriever.add_feedback(feedback_text) - - assert result == [feedback_text] - mock_add_data.assert_awaited_once() - mock_graph_engine.add_edges.assert_awaited_once() - mock_index_edges.assert_awaited_once() - mock_graph_engine.apply_feedback_weight.assert_awaited_once() - - # Verify add_edges was called with correct relationships - call_args = mock_graph_engine.add_edges.call_args[0][0] - assert len(call_args) == 2 - assert call_args[0][0] == uuid5(NAMESPACE_OID, name=feedback_text) - assert call_args[0][1] == UUID(interaction_id_1) - assert call_args[0][2] == "gives_feedback_to" - assert call_args[0][3]["relationship_name"] == "gives_feedback_to" - assert call_args[0][3]["ontology_valid"] is False - - # Verify apply_feedback_weight was called with correct node IDs - weight_call_args = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"] - assert len(weight_call_args) == 2 - assert interaction_id_1 in weight_call_args - assert interaction_id_2 in weight_call_args - - @pytest.mark.asyncio - async def test_add_feedback_success_no_relationships( - self, mock_feedback_evaluation, mock_graph_engine - ): - """Test add_feedback successfully creates feedback without relationships.""" - mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) - - feedback_text = "This answer was helpful" - - with ( - patch( - "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_feedback_evaluation, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.add_data_points", - new_callable=AsyncMock, - ) as mock_add_data, - patch( - "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", - new_callable=AsyncMock, - ) as mock_index_edges, - ): - retriever = UserQAFeedback(last_k=1) - result = await retriever.add_feedback(feedback_text) - - assert result == [feedback_text] - mock_add_data.assert_awaited_once() - # Should not call add_edges or index_graph_edges when no relationships - mock_graph_engine.add_edges.assert_not_awaited() - mock_index_edges.assert_not_awaited() - mock_graph_engine.apply_feedback_weight.assert_not_awaited() - - @pytest.mark.asyncio - async def test_add_feedback_creates_correct_feedback_node( - self, mock_feedback_evaluation, mock_graph_engine - ): - """Test add_feedback creates CogneeUserFeedback with correct attributes.""" - mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) - - feedback_text = "This was a negative experience" - mock_feedback_evaluation.evaluation.value = "negative" - mock_feedback_evaluation.score = -3.0 - - with ( - patch( - "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_feedback_evaluation, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.add_data_points", - new_callable=AsyncMock, - ) as mock_add_data, - ): - retriever = UserQAFeedback() - await retriever.add_feedback(feedback_text) - - # Verify add_data_points was called with correct CogneeUserFeedback - call_args = mock_add_data.call_args[1]["data_points"] - assert len(call_args) == 1 - feedback_node = call_args[0] - assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text) - assert feedback_node.feedback == feedback_text - assert feedback_node.sentiment == "negative" - assert feedback_node.score == -3.0 - assert isinstance(feedback_node.belongs_to_set, NodeSet) - assert feedback_node.belongs_to_set.name == "UserQAFeedbacks" - - @pytest.mark.asyncio - async def test_add_feedback_calls_llm_with_correct_prompt( - self, mock_feedback_evaluation, mock_graph_engine - ): - """Test add_feedback calls LLM with correct sentiment analysis prompt.""" - mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) - - feedback_text = "Great answer!" - - with ( - patch( - "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_feedback_evaluation, - ) as mock_llm, - patch( - "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.add_data_points", - new_callable=AsyncMock, - ), - ): - retriever = UserQAFeedback() - await retriever.add_feedback(feedback_text) - - mock_llm.assert_awaited_once() - call_kwargs = mock_llm.call_args[1] - assert call_kwargs["text_input"] == feedback_text - assert "sentiment analysis assistant" in call_kwargs["system_prompt"] - assert call_kwargs["response_model"] == UserFeedbackEvaluation - - @pytest.mark.asyncio - async def test_add_feedback_uses_last_k_parameter( - self, mock_feedback_evaluation, mock_graph_engine - ): - """Test add_feedback uses last_k parameter when getting interaction IDs.""" - mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) - - feedback_text = "Test feedback" - - with ( - patch( - "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_feedback_evaluation, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.add_data_points", - new_callable=AsyncMock, - ), - ): - retriever = UserQAFeedback(last_k=5) - await retriever.add_feedback(feedback_text) - - mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5) - - @pytest.mark.asyncio - async def test_add_feedback_with_single_interaction( - self, mock_feedback_evaluation, mock_graph_engine - ): - """Test add_feedback with single interaction ID.""" - interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) - mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) - - feedback_text = "Test feedback" - - with ( - patch( - "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_feedback_evaluation, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.add_data_points", - new_callable=AsyncMock, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", - new_callable=AsyncMock, - ), - ): - retriever = UserQAFeedback() - result = await retriever.add_feedback(feedback_text) - - assert result == [feedback_text] - # Should create relationship for the interaction - call_args = mock_graph_engine.add_edges.call_args[0][0] - assert len(call_args) == 1 - assert call_args[0][1] == UUID(interaction_id) - - @pytest.mark.asyncio - async def test_add_feedback_applies_weight_correctly( - self, mock_feedback_evaluation, mock_graph_engine - ): - """Test add_feedback applies feedback weight with correct score.""" - interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) - mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) - mock_feedback_evaluation.score = 4.5 - - feedback_text = "Positive feedback" - - with ( - patch( - "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", - new_callable=AsyncMock, - return_value=mock_feedback_evaluation, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.add_data_points", - new_callable=AsyncMock, - ), - patch( - "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", - new_callable=AsyncMock, - ), - ): - retriever = UserQAFeedback() - await retriever.add_feedback(feedback_text) - - mock_graph_engine.apply_feedback_weight.assert_awaited_once_with( - node_ids=[interaction_id], weight=4.5 - ) diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py deleted file mode 100644 index 83612c7aa..000000000 --- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +++ /dev/null @@ -1,329 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock - -from cognee.modules.retrieval.triplet_retriever import TripletRetriever -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError - - -@pytest.fixture -def mock_vector_engine(): - """Create a mock vector engine.""" - engine = AsyncMock() - engine.has_collection = AsyncMock(return_value=True) - engine.search = AsyncMock() - return engine - - -@pytest.mark.asyncio -async def test_get_context_success(mock_vector_engine): - """Test successful retrieval of triplet context.""" - mock_result1 = MagicMock() - mock_result1.payload = {"text": "Alice knows Bob"} - mock_result2 = MagicMock() - mock_result2.payload = {"text": "Bob works at Tech Corp"} - - mock_vector_engine.search.return_value = [mock_result1, mock_result2] - - retriever = TripletRetriever(top_k=5) - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "Alice knows Bob\nBob works at Tech Corp" - mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5) - - -@pytest.mark.asyncio -async def test_get_context_no_collection(mock_vector_engine): - """Test that NoDataError is raised when Triplet_text collection doesn't exist.""" - mock_vector_engine.has_collection.return_value = False - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="create_triplet_embeddings"): - await retriever.get_context("test query") - - -@pytest.mark.asyncio -async def test_get_context_empty_results(mock_vector_engine): - """Test that empty string is returned when no triplets are found.""" - mock_vector_engine.search.return_value = [] - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "" - - -@pytest.mark.asyncio -async def test_get_context_collection_not_found_error(mock_vector_engine): - """Test that CollectionNotFoundError is converted to NoDataError.""" - mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found") - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="No data found"): - await retriever.get_context("test query") - - -@pytest.mark.asyncio -async def test_get_context_empty_payload_text(mock_vector_engine): - """Test get_context handles missing text in payload.""" - mock_result = MagicMock() - mock_result.payload = {} - - mock_vector_engine.search.return_value = [mock_result] - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(KeyError): - await retriever.get_context("test query") - - -@pytest.mark.asyncio -async def test_get_context_single_triplet(mock_vector_engine): - """Test get_context with single triplet result.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Single triplet"} - - mock_vector_engine.search.return_value = [mock_result] - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "Single triplet" - - -@pytest.mark.asyncio -async def test_init_defaults(): - """Test TripletRetriever initialization with defaults.""" - retriever = TripletRetriever() - - assert retriever.user_prompt_path == "context_for_question.txt" - assert retriever.system_prompt_path == "answer_simple_question.txt" - assert retriever.top_k == 5 # Default is 5 - assert retriever.system_prompt is None - - -@pytest.mark.asyncio -async def test_init_custom_params(): - """Test TripletRetriever initialization with custom parameters.""" - retriever = TripletRetriever( - user_prompt_path="custom_user.txt", - system_prompt_path="custom_system.txt", - system_prompt="Custom prompt", - top_k=10, - ) - - assert retriever.user_prompt_path == "custom_user.txt" - assert retriever.system_prompt_path == "custom_system.txt" - assert retriever.system_prompt == "Custom prompt" - assert retriever.top_k == 10 - - -@pytest.mark.asyncio -async def test_get_completion_without_context(mock_vector_engine): - """Test get_completion retrieves context when not provided.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Test triplet"} - mock_vector_engine.has_collection.return_value = True - mock_vector_engine.search.return_value = [mock_result] - - retriever = TripletRetriever() - - with ( - patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.triplet_retriever.generate_completion", - return_value="Generated answer", - ), - patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_provided_context(mock_vector_engine): - """Test get_completion uses provided context.""" - retriever = TripletRetriever() - - with ( - patch( - "cognee.modules.retrieval.triplet_retriever.generate_completion", - return_value="Generated answer", - ), - patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", context="Provided context") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - - -@pytest.mark.asyncio -async def test_get_completion_with_session(mock_vector_engine): - """Test get_completion with session caching enabled.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Test triplet"} - mock_vector_engine.has_collection.return_value = True - mock_vector_engine.search.return_value = [mock_result] - - retriever = TripletRetriever() - - mock_user = MagicMock() - mock_user.id = "test-user-id" - - with ( - patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.triplet_retriever.get_conversation_history", - return_value="Previous conversation", - ), - patch( - "cognee.modules.retrieval.triplet_retriever.summarize_text", - return_value="Context summary", - ), - patch( - "cognee.modules.retrieval.triplet_retriever.generate_completion", - return_value="Generated answer", - ), - patch( - "cognee.modules.retrieval.triplet_retriever.save_conversation_history", - ) as mock_save, - patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, - patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = mock_user - - completion = await retriever.get_completion("test query", session_id="test_session") - - assert isinstance(completion, list) - assert len(completion) == 1 - assert completion[0] == "Generated answer" - mock_save.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_get_completion_with_session_no_user_id(mock_vector_engine): - """Test get_completion with session config but no user ID.""" - mock_result = MagicMock() - mock_result.payload = {"text": "Test triplet"} - mock_vector_engine.has_collection.return_value = True - mock_vector_engine.search.return_value = [mock_result] - - retriever = TripletRetriever() - - with ( - patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.triplet_retriever.generate_completion", - return_value="Generated answer", - ), - patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, - patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, - ): - mock_config = MagicMock() - mock_config.caching = True - mock_cache_config.return_value = mock_config - mock_session_user.get.return_value = None # No user - - completion = await retriever.get_completion("test query") - - assert isinstance(completion, list) - assert len(completion) == 1 - - -@pytest.mark.asyncio -async def test_get_completion_with_response_model(mock_vector_engine): - """Test get_completion with custom response model.""" - from pydantic import BaseModel - - class TestModel(BaseModel): - answer: str - - mock_result = MagicMock() - mock_result.payload = {"text": "Test triplet"} - mock_vector_engine.has_collection.return_value = True - mock_vector_engine.search.return_value = [mock_result] - - retriever = TripletRetriever() - - with ( - patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.triplet_retriever.generate_completion", - return_value=TestModel(answer="Test answer"), - ), - patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, - ): - mock_config = MagicMock() - mock_config.caching = False - mock_cache_config.return_value = mock_config - - completion = await retriever.get_completion("test query", response_model=TestModel) - - assert isinstance(completion, list) - assert len(completion) == 1 - assert isinstance(completion[0], TestModel) - - -@pytest.mark.asyncio -async def test_init_none_top_k(): - """Test TripletRetriever initialization with None top_k.""" - retriever = TripletRetriever(top_k=None) - - assert retriever.top_k == 5