diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py index 3dc9f38d9..b7cbe08d7 100644 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -1,12 +1,14 @@ import pytest -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, patch, MagicMock from cognee.modules.retrieval.utils.brute_force_triplet_search import ( brute_force_triplet_search, get_memory_fragment, + format_triplets, ) from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError +from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError class MockScoredResult: @@ -354,20 +356,30 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation @pytest.mark.asyncio async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): - """Test that get_memory_fragment returns empty graph when entity not found.""" + """Test that get_memory_fragment returns empty graph when entity not found (line 85).""" mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock( + + # Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called + mock_fragment = MagicMock(spec=CogneeGraph) + mock_fragment.project_graph_from_db = AsyncMock( side_effect=EntityNotFoundError("Entity not found") ) - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph", + return_value=mock_fragment, + ), ): - fragment = await get_memory_fragment() + result = await get_memory_fragment() - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 + # Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85) + assert result == mock_fragment + mock_fragment.project_graph_from_db.assert_awaited_once() @pytest.mark.asyncio @@ -606,3 +618,200 @@ async def test_brute_force_triplet_search_mixed_empty_collections(): call_kwargs = mock_get_fragment_fn.call_args[1] assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +def test_format_triplets(): + """Test format_triplets function.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"} + mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "connects" in result + + +def test_format_triplets_with_none_values(): + """Test format_triplets filters out None values.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None} + mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "None" not in result or result.count("None") == 0 + + +def test_format_triplets_with_nested_dict(): + """Test format_triplets handles nested dict attributes (lines 23-35).""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}} + mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}} + mock_edge.attributes = {"relationship_name": "relates_to"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_vector_engine_init_error(): + """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147).""" + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine" + ) as mock_get_vector_engine, + ): + mock_get_vector_engine.side_effect = Exception("Initialization error") + + with pytest.raises(RuntimeError, match="Initialization error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_error(): + """Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock( + side_effect=[ + CollectionNotFoundError("Collection not found"), + [], + [], + ] + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=CogneeGraph(), + ), + ): + result = await brute_force_triplet_search( + query="test query", collections=["missing_collection", "existing_collection"] + ) + + assert result == [] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_generic_exception(): + """Test brute_force_triplet_search handles generic exceptions (lines 209-217).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock(side_effect=Exception("Generic error")) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + with pytest.raises(Exception, match="Generic error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none(): + """Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test query", node_name=["Node1"]) + + assert mock_get_fragment.called + call_kwargs = mock_get_fragment.call_args.kwargs if mock_get_fragment.call_args else {} + assert call_kwargs.get("relevant_ids_to_filter") is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_at_top_level(): + """Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock( + side_effect=CollectionNotFoundError("Collection not found") + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + result = await brute_force_triplet_search(query="test query") + + assert result == []