diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py
index 00db1e794..fcbfd2434 100644
--- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py
+++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py
@@ -30,7 +30,7 @@ async def test_brute_force_triplet_search_empty_query():
 @pytest.mark.asyncio
 async def test_brute_force_triplet_search_none_query():
     """Test that None query raises ValueError."""
-    with pytest.raises(ValueError, match="The query must be a non-empty string."):
+    with pytest.raises(ValueError, match=r"Must provide either 'query' or 'query_batch'\."):
         await brute_force_triplet_search(query=None)
 
 
@@ -351,7 +351,9 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
         custom_top_k = 15
         await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
 
-        mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
+        mock_fragment.calculate_top_triplet_importances.assert_called_once_with(
+            k=custom_top_k, query_list_length=None
+        )
 
 
 @pytest.mark.asyncio
@@ -815,3 +817,238 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level():
         result = await brute_force_triplet_search(query="test query")
 
     assert result == []
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_single_query_regression():
+    """Test single-query mode: flat result list and a non-None relevant_ids_to_filter."""
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("node1", 0.95)])
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ) as mock_get_fragment,
+    ):
+        result = await brute_force_triplet_search(
+            query="q1", query_batch=None, wide_search_top_k=10, node_name=None
+        )
+
+    assert isinstance(result, list)
+    assert not (result and isinstance(result[0], list))
+    mock_get_fragment.assert_called_once()
+    call_kwargs = mock_get_fragment.call_args[1]
+    assert call_kwargs["relevant_ids_to_filter"] is not None
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_batch_wiring_happy_path():
+    """Test batch mode: one result list per query and relevant_ids_to_filter=None."""
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.batch_search = AsyncMock(
+        return_value=[
+            [MockScoredResult("node1", 0.95)],
+            [MockScoredResult("node2", 0.87)],
+        ]
+    )
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ) as mock_get_fragment,
+    ):
+        result = await brute_force_triplet_search(query_batch=["q1", "q2"])
+
+    assert isinstance(result, list)
+    assert len(result) == 2
+    assert isinstance(result[0], list)
+    assert isinstance(result[1], list)
+    mock_get_fragment.assert_called_once()
+    call_kwargs = mock_get_fragment.call_args[1]
+    assert call_kwargs["relevant_ids_to_filter"] is None
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_shape_propagation_to_graph():
+    """Test that query_list_length reaches node/edge mapping and importance calculation."""
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.batch_search = AsyncMock(
+        return_value=[
+            [MockScoredResult("node1", 0.95)],
+            [MockScoredResult("node2", 0.87)],
+        ]
+    )
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ),
+    ):
+        await brute_force_triplet_search(query_batch=["q1", "q2"])
+
+    mock_fragment.map_vector_distances_to_graph_nodes.assert_called_once()
+    node_call_kwargs = mock_fragment.map_vector_distances_to_graph_nodes.call_args[1]
+    assert "query_list_length" in node_call_kwargs
+    assert node_call_kwargs["query_list_length"] == 2
+
+    mock_fragment.map_vector_distances_to_graph_edges.assert_called_once()
+    edge_call_kwargs = mock_fragment.map_vector_distances_to_graph_edges.call_args[1]
+    assert "query_list_length" in edge_call_kwargs
+    assert edge_call_kwargs["query_list_length"] == 2
+
+    mock_fragment.calculate_top_triplet_importances.assert_called_once()
+    importance_call_kwargs = mock_fragment.calculate_top_triplet_importances.call_args[1]
+    assert "query_list_length" in importance_call_kwargs
+    assert importance_call_kwargs["query_list_length"] == 2
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_batch_path_comprehensive():
+    """Test batch mode: list-of-lists result, no ID filtering, limit=None for batch_search."""
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+
+    def batch_search_side_effect(*args, **kwargs):
+        collection_name = kwargs.get("collection_name")
+        if collection_name == "Entity_name":
+            return [
+                [MockScoredResult("node1", 0.95)],
+                [MockScoredResult("node2", 0.87)],
+            ]
+        elif collection_name == "EdgeType_relationship_name":
+            return [
+                [MockScoredResult("edge1", 0.92)],
+                [MockScoredResult("edge2", 0.88)],
+            ]
+        return [[], []]
+
+    mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ) as mock_get_fragment,
+    ):
+        result = await brute_force_triplet_search(
+            query_batch=["q1", "q2"], collections=["Entity_name", "EdgeType_relationship_name"]
+        )
+
+    assert isinstance(result, list)
+    assert len(result) == 2
+    assert isinstance(result[0], list)
+    assert isinstance(result[1], list)
+
+    mock_get_fragment.assert_called_once()
+    fragment_call_kwargs = mock_get_fragment.call_args[1]
+    assert fragment_call_kwargs["relevant_ids_to_filter"] is None
+
+    batch_search_calls = mock_vector_engine.batch_search.call_args_list
+    assert len(batch_search_calls) > 0
+    for call in batch_search_calls:
+        assert call[1]["limit"] is None
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_batch_error_fallback():
+    """Test that CollectionNotFoundError in batch mode returns one empty list per query."""
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.batch_search = AsyncMock(
+        side_effect=CollectionNotFoundError("Collection not found")
+    )
+
+    with patch(
+        "cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        result = await brute_force_triplet_search(query_batch=["q1", "q2"])
+
+    assert result == [[], []]
+    assert len(result) == 2
+
+
+@pytest.mark.asyncio
+async def test_cognee_graph_mapping_batch_shapes():
+    """Test that CogneeGraph mapping stores one vector_distance entry per batched query."""
+    from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
+
+    graph = CogneeGraph()
+    node1 = Node("node1", {"name": "Node1"})
+    node2 = Node("node2", {"name": "Node2"})
+    graph.add_node(node1)
+    graph.add_node(node2)
+
+    edge = Edge(node1, node2, attributes={"edge_text": "relates_to"})
+    graph.add_edge(edge)
+
+    node_distances_batch = {
+        "Entity_name": [
+            [MockScoredResult("node1", 0.95)],
+            [MockScoredResult("node2", 0.87)],
+        ]
+    }
+
+    edge_distances_batch = [
+        [MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
+        [MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
+    ]
+
+    await graph.map_vector_distances_to_graph_nodes(
+        node_distances=node_distances_batch, query_list_length=2
+    )
+    await graph.map_vector_distances_to_graph_edges(
+        edge_distances=edge_distances_batch, query_list_length=2
+    )
+
+    # NOTE(review): 3.5 looks like the default distance for unmatched queries - confirm.
+    assert node1.attributes.get("vector_distance") == [0.95, 3.5]
+    assert node2.attributes.get("vector_distance") == [3.5, 0.87]
+    assert edge.attributes.get("vector_distance") == [0.92, 0.88]
diff --git a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py
index d93dce42b..1fd169fcc 100644
--- a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py
+++ b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py
@@ -212,3 +212,29 @@ async def test_node_edge_vector_search_single_query_collection_not_found():
     )
 
     assert vector_search.node_distances["MissingCollection"] == []
+
+
+@pytest.mark.asyncio
+async def test_node_edge_vector_search_has_results_batch_nodes_only():
+    """Test has_results returns True when only node distances are populated in batch mode."""
+    mock_vector_engine = AsyncMock()
+    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
+    vector_search.query_list_length = 2
+    vector_search.edge_distances = [[], []]
+    vector_search.node_distances = {
+        "Entity_name": [[MockScoredResult("node1", 0.95)], []],
+    }
+
+    assert vector_search.has_results() is True
+
+
+@pytest.mark.asyncio
+async def test_node_edge_vector_search_has_results_batch_edges_only():
+    """Test has_results returns True when only edge distances are populated in batch mode."""
+    mock_vector_engine = AsyncMock()
+    vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
+    vector_search.query_list_length = 2
+    vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
+    vector_search.node_distances = {}
+
+    assert vector_search.has_results() is True