chore: update tests and minor tweaks
This commit is contained in:
parent
c20304a92a
commit
1c8d0f6da1
3 changed files with 85 additions and 1 deletions
|
|
@ -49,7 +49,12 @@ class NodeEdgeVectorSearch:
|
|||
search_results: List[List[Any]],
|
||||
query_list_length: Optional[int] = None,
|
||||
):
|
||||
"""Separates search results into node and edge distances with stable shapes."""
|
||||
"""Separates search results into node and edge distances with stable shapes.
|
||||
|
||||
Ensures all collections are present in the output, even if empty:
|
||||
- Batch mode: missing/empty collections become [[]] * query_list_length
|
||||
- Single mode: missing/empty collections become []
|
||||
"""
|
||||
self.node_distances = {}
|
||||
self.edge_distances = (
|
||||
[] if query_list_length is None else [[] for _ in range(query_list_length)]
|
||||
|
|
|
|||
|
|
@ -718,3 +718,49 @@ async def test_calculate_top_triplet_importances_raises_on_missing_attribute(set
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
await graph.calculate_top_triplet_importances(k=1, query_list_length=1)
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_flat_list_single_query(setup_graph):
|
||||
"""Test that flat list is normalized to list-of-lists with length 1 for single-query mode."""
|
||||
graph = setup_graph
|
||||
flat_list = [MockScoredResult("node1", 0.95), MockScoredResult("node2", 0.87)]
|
||||
|
||||
result = graph._normalize_query_distance_lists(flat_list, query_list_length=None, name="test")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] == flat_list
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_nested_list_batch_mode(setup_graph):
|
||||
"""Test that nested list is used as-is when query_list_length matches."""
|
||||
graph = setup_graph
|
||||
nested_list = [
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
|
||||
result = graph._normalize_query_distance_lists(nested_list, query_list_length=2, name="test")
|
||||
|
||||
assert len(result) == 2
|
||||
assert result == nested_list
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_raises_on_length_mismatch(setup_graph):
|
||||
"""Test that ValueError is raised when nested list length doesn't match query_list_length."""
|
||||
graph = setup_graph
|
||||
nested_list = [
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="test has 2 query lists, but query_list_length is 3"):
|
||||
graph._normalize_query_distance_lists(nested_list, query_list_length=3, name="test")
|
||||
|
||||
|
||||
def test_normalize_query_distance_lists_empty_list(setup_graph):
|
||||
"""Test that empty list returns empty list."""
|
||||
graph = setup_graph
|
||||
|
||||
result = graph._normalize_query_distance_lists([], query_list_length=None, name="test")
|
||||
|
||||
assert result == []
|
||||
|
|
|
|||
|
|
@ -214,6 +214,39 @@ async def test_node_edge_vector_search_single_query_collection_not_found():
|
|||
assert vector_search.node_distances["MissingCollection"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_missing_collections_single_query():
|
||||
"""Test that missing collections in single-query mode are handled gracefully with empty lists."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
|
||||
node_result = MockScoredResult("node1", 0.95)
|
||||
|
||||
def search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return [node_result]
|
||||
elif collection_name == "MissingCollection":
|
||||
raise CollectionNotFoundError("Collection not found")
|
||||
return []
|
||||
|
||||
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
||||
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
collections = ["Entity_name", "MissingCollection", "EmptyCollection"]
|
||||
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query="test query", query_batch=None, collections=collections, wide_search_limit=10
|
||||
)
|
||||
|
||||
assert len(vector_search.node_distances["Entity_name"]) == 1
|
||||
assert vector_search.node_distances["Entity_name"][0].id == "node1"
|
||||
assert vector_search.node_distances["Entity_name"][0].score == 0.95
|
||||
assert vector_search.node_distances["MissingCollection"] == []
|
||||
assert vector_search.node_distances["EmptyCollection"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_has_results_batch_nodes_only():
|
||||
"""Test has_results returns True when only node distances are populated in batch mode."""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue