From aa263142b8e02b686378f76f3ea5f58058b05730 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 11 Dec 2025 10:58:58 +0100 Subject: [PATCH] increasing coverage of graph completion --- .../graph_completion_retriever_test.py | 428 ++++++++++++++++++ 1 file changed, 428 insertions(+) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 70e0f7715..034b7016a 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -156,8 +156,436 @@ async def test_init_custom_params(): top_k=10, user_prompt_path="custom_user.txt", system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + node_type=str, + node_name=["node1"], + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=5.0, ) assert retriever.top_k == 10 assert retriever.user_prompt_path == "custom_user.txt" assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.node_type == str + assert retriever.node_name == ["node1"] + assert retriever.save_interaction is True + assert retriever.wide_search_top_k == 200 + assert retriever.triplet_distance_penalty == 5.0 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test GraphCompletionRetriever initialization with None top_k.""" + retriever = GraphCompletionRetriever(top_k=None) + + assert retriever.top_k == 5 # None defaults to 5 + + +@pytest.mark.asyncio +async def test_convert_retrieved_objects_to_context(mock_edge): + """Test convert_retrieved_objects_to_context method.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved text", + ) as mock_resolve: + result = await retriever.convert_retrieved_objects_to_context([mock_edge]) + + assert result == "Resolved text" + mock_resolve.assert_awaited_once_with([mock_edge]) + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge]) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_empty_context(mock_edge): + """Test get_completion with empty context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_save_qa(mock_edge): + """Test save_qa method.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=["uuid1", "uuid2"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], + ) + + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_save_qa_no_triplet_ids(mock_edge): + """Test save_qa when triplets have no extractable IDs.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + return_value=None, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], + ) + + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() + + +@pytest.mark.asyncio +async def test_save_qa_empty_triplets(): + """Test save_qa with empty triplets list.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[], + ) + + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_completion(mock_edge): + """Test get_completion with save_interaction but no completion.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=None, # No completion + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] is None + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=None) + + assert isinstance(completion, list) + assert len(completion) == 1