chore: update tests and minor tweaks

This commit is contained in:
lxobr 2026-01-12 14:51:50 +01:00
parent c20304a92a
commit 1c8d0f6da1
3 changed files with 85 additions and 1 deletions

View file

@ -49,7 +49,12 @@ class NodeEdgeVectorSearch:
search_results: List[List[Any]],
query_list_length: Optional[int] = None,
):
"""Separates search results into node and edge distances with stable shapes."""
"""Separates search results into node and edge distances with stable shapes.
Ensures all collections are present in the output, even if empty:
- Batch mode: missing/empty collections become [[]] * query_list_length
- Single mode: missing/empty collections become []
"""
self.node_distances = {}
self.edge_distances = (
[] if query_list_length is None else [[] for _ in range(query_list_length)]

View file

@ -718,3 +718,49 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set
with pytest.raises(ValueError):
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
def test_normalize_query_distance_lists_flat_list_single_query(setup_graph):
"""Test that flat list is normalized to list-of-lists with length 1 for single-query mode."""
graph = setup_graph
flat_list = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
result = graph._normalize_query_distance_lists(flat_list, query_list_length=None, name="test")
assert len(result) == 1
assert result[0] == flat_list
def test_normalize_query_distance_lists_nested_list_batch_mode(setup_graph):
"""Test that nested list is used as-is when query_list_length matches."""
graph = setup_graph
nested_list = [
[MockScoredResult("node1", 0.95)],
[MockScoredResult("node2", 0.87)],
]
result = graph._normalize_query_distance_lists(nested_list, query_list_length=2, name="test")
assert len(result) == 2
assert result == nested_list
def test_normalize_query_distance_lists_raises_on_length_mismatch(setup_graph):
"""Test that ValueError is raised when nested list length doesn't match query_list_length."""
graph = setup_graph
nested_list = [
[MockScoredResult("node1", 0.95)],
[MockScoredResult("node2", 0.87)],
]
with pytest.raises(ValueError, match="test has 2 query lists, but query_list_length is 3"):
graph._normalize_query_distance_lists(nested_list, query_list_length=3, name="test")
def test_normalize_query_distance_lists_empty_list(setup_graph):
"""Test that empty list returns empty list."""
graph = setup_graph
result = graph._normalize_query_distance_lists([], query_list_length=None, name="test")
assert result == []

View file

@ -214,6 +214,39 @@ async def test_node_edge_vector_search_single_query_collection_not_found():
assert vector_search.node_distances["MissingCollection"] == []
@pytest.mark.asyncio
async def test_node_edge_vector_search_missing_collections_single_query():
"""Test that missing collections in single-query mode are handled gracefully with empty lists."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
node_result = MockScoredResult("node1", 0.95)
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [node_result]
elif collection_name == "MissingCollection":
raise CollectionNotFoundError("Collection not found")
return []
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
collections = ["Entity_name", "MissingCollection", "EmptyCollection"]
await vector_search.embed_and_retrieve_distances(
query="test query", query_batch=None, collections=collections, wide_search_limit=10
)
assert len(vector_search.node_distances["Entity_name"]) == 1
assert vector_search.node_distances["Entity_name"][0].id == "node1"
assert vector_search.node_distances["Entity_name"][0].score == 0.95
assert vector_search.node_distances["MissingCollection"] == []
assert vector_search.node_distances["EmptyCollection"] == []
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_batch_nodes_only():
"""Test has_results returns True when only node distances are populated in batch mode."""