From 1c8d0f6da1416b0747c3c75f5f6355beb8100243 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 12 Jan 2026 14:51:50 +0100 Subject: [PATCH] chore: update tests and minor tweaks --- .../utils/node_edge_vector_search.py | 7 ++- .../unit/modules/graph/cognee_graph_test.py | 46 +++++++++++++++++++ .../retrieval/test_node_edge_vector_search.py | 33 +++++++++++++ 3 files changed, 85 insertions(+), 1 deletion(-) diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py index ff2d98eb8..80116f6f2 100644 --- a/cognee/modules/retrieval/utils/node_edge_vector_search.py +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -49,7 +49,12 @@ class NodeEdgeVectorSearch: search_results: List[List[Any]], query_list_length: Optional[int] = None, ): - """Separates search results into node and edge distances with stable shapes.""" + """Separates search results into node and edge distances with stable shapes. + + Ensures all collections are present in the output, even if empty: + - Batch mode: missing/empty collections become [[]] * query_list_length + - Single mode: missing/empty collections become [] + """ self.node_distances = {} self.edge_distances = ( [] if query_list_length is None else [[] for _ in range(query_list_length)] diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 41f12e73a..a13031ac5 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -718,3 +718,49 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set with pytest.raises(ValueError): await graph.calculate_top_triplet_importances(k=1, query_list_length=1) + + +def test_normalize_query_distance_lists_flat_list_single_query(setup_graph): + """Test that flat list is normalized to list-of-lists with length 1 for single-query mode.""" + graph = setup_graph + flat_list = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)] + + result = graph._normalize_query_distance_lists(flat_list, query_list_length=None, name="test") + + assert len(result) == 1 + assert result[0] == flat_list + + +def test_normalize_query_distance_lists_nested_list_batch_mode(setup_graph): + """Test that nested list is used as-is when query_list_length matches.""" + graph = setup_graph + nested_list = [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + + result = graph._normalize_query_distance_lists(nested_list, query_list_length=2, name="test") + + assert len(result) == 2 + assert result == nested_list + + +def test_normalize_query_distance_lists_raises_on_length_mismatch(setup_graph): + """Test that ValueError is raised when nested list length doesn't match query_list_length.""" + graph = setup_graph + nested_list = [ + [MockScoredResult("node1", 0.95)], + [MockScoredResult("node2", 0.87)], + ] + + with pytest.raises(ValueError, match="test has 2 query lists, but query_list_length is 3"): + graph._normalize_query_distance_lists(nested_list, query_list_length=3, name="test") + + +def test_normalize_query_distance_lists_empty_list(setup_graph): + """Test that empty list returns empty list.""" + graph = setup_graph + + result = graph._normalize_query_distance_lists([], query_list_length=None, name="test") + + assert result == [] diff --git a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py index 1fd169fcc..98d76ddef 100644 --- a/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py +++ b/cognee/tests/unit/modules/retrieval/test_node_edge_vector_search.py @@ -214,6 +214,39 @@ async def test_node_edge_vector_search_single_query_collection_not_found(): assert vector_search.node_distances["MissingCollection"] == [] +@pytest.mark.asyncio +async def test_node_edge_vector_search_missing_collections_single_query(): + """Test that missing collections in single-query mode are handled gracefully with empty lists.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + node_result = MockScoredResult("node1", 0.95) + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [node_result] + elif collection_name == "MissingCollection": + raise CollectionNotFoundError("Collection not found") + return [] + + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine) + collections = ["Entity_name", "MissingCollection", "EmptyCollection"] + + await vector_search.embed_and_retrieve_distances( + query="test query", query_batch=None, collections=collections, wide_search_limit=10 + ) + + assert len(vector_search.node_distances["Entity_name"]) == 1 + assert vector_search.node_distances["Entity_name"][0].id == "node1" + assert vector_search.node_distances["Entity_name"][0].score == 0.95 + assert vector_search.node_distances["MissingCollection"] == [] + assert vector_search.node_distances["EmptyCollection"] == [] + + @pytest.mark.asyncio async def test_node_edge_vector_search_has_results_batch_nodes_only(): """Test has_results returns True when only node distances are populated in batch mode."""