tests: update and expand test_brute_force_triplet_search.py and test_node_edge_vector_search.py

This commit is contained in:
lxobr 2026-01-12 14:40:55 +01:00
parent 7833189001
commit c20304a92a
2 changed files with 264 additions and 2 deletions

View file

@ -30,7 +30,7 @@ async def test_brute_force_triplet_search_empty_query():
@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
"""Test that None query raises ValueError."""
with pytest.raises(ValueError, match="The query must be a non-empty string."):
with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'."):
await brute_force_triplet_search(query=None)
@ -351,7 +351,9 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
custom_top_k = 15
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(
k=custom_top_k, query_list_length=None
)
@pytest.mark.asyncio
@ -815,3 +817,237 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level():
result = await brute_force_triplet_search(query="test query")
assert result == []
@pytest.mark.asyncio
async def test_brute_force_triplet_search_single_query_regression():
"""Test that single-query mode maintains legacy behavior (flat list, ID filtering)."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("node1", 0.95)])
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment,
):
result = await brute_force_triplet_search(
query="q1", query_batch=None, wide_search_top_k=10, node_name=None
)
assert isinstance(result, list)
assert not (result and isinstance(result[0], list))
mock_get_fragment.assert_called_once()
call_kwargs = mock_get_fragment.call_args[1]
assert call_kwargs["relevant_ids_to_filter"] is not None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_wiring_happy_path():
"""Test that batch mode returns list-of-lists and skips ID filtering."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.batch_search = AsyncMock(
return_value=[
[MockScoredResult("node1", 0.95)],
[MockScoredResult("node2", 0.87)],
]
)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
)
with (
patch(
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment,
):
result = await brute_force_triplet_search(query_batch=["q1", "q2"])
assert isinstance(result, list)
assert len(result) == 2
assert isinstance(result[0], list)
assert isinstance(result[1], list)
mock_get_fragment.assert_called_once()
call_kwargs = mock_get_fragment.call_args[1]
assert call_kwargs["relevant_ids_to_filter"] is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_shape_propagation_to_graph():
"""Test that query_list_length is passed through to graph mapping methods."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.batch_search = AsyncMock(
return_value=[
[MockScoredResult("node1", 0.95)],
[MockScoredResult("node2", 0.87)],
]
)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
)
with (
patch(
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
),
):
await brute_force_triplet_search(query_batch=["q1", "q2"])
mock_fragment.map_vector_distances_to_graph_nodes.assert_called_once()
node_call_kwargs = mock_fragment.map_vector_distances_to_graph_nodes.call_args[1]
assert "query_list_length" in node_call_kwargs
assert node_call_kwargs["query_list_length"] == 2
mock_fragment.map_vector_distances_to_graph_edges.assert_called_once()
edge_call_kwargs = mock_fragment.map_vector_distances_to_graph_edges.call_args[1]
assert "query_list_length" in edge_call_kwargs
assert edge_call_kwargs["query_list_length"] == 2
mock_fragment.calculate_top_triplet_importances.assert_called_once()
importance_call_kwargs = mock_fragment.calculate_top_triplet_importances.call_args[1]
assert "query_list_length" in importance_call_kwargs
assert importance_call_kwargs["query_list_length"] == 2
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_path_comprehensive():
"""Test batch mode: returns list-of-lists, skips ID filtering, passes None for wide_search_limit."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
def batch_search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [
[MockScoredResult("node1", 0.95)],
[MockScoredResult("node2", 0.87)],
]
elif collection_name == "EdgeType_relationship_name":
return [
[MockScoredResult("edge1", 0.92)],
[MockScoredResult("edge2", 0.88)],
]
return [[], []]
mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
)
with (
patch(
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment,
):
result = await brute_force_triplet_search(
query_batch=["q1", "q2"], collections=["Entity_name", "EdgeType_relationship_name"]
)
assert isinstance(result, list)
assert len(result) == 2
assert isinstance(result[0], list)
assert isinstance(result[1], list)
mock_get_fragment.assert_called_once()
fragment_call_kwargs = mock_get_fragment.call_args[1]
assert fragment_call_kwargs["relevant_ids_to_filter"] is None
batch_search_calls = mock_vector_engine.batch_search.call_args_list
assert len(batch_search_calls) > 0
for call in batch_search_calls:
assert call[1]["limit"] is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_batch_error_fallback():
"""Test that CollectionNotFoundError in batch mode returns [[], []] matching batch length."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.batch_search = AsyncMock(
side_effect=CollectionNotFoundError("Collection not found")
)
with patch(
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
return_value=mock_vector_engine,
):
result = await brute_force_triplet_search(query_batch=["q1", "q2"])
assert result == [[], []]
assert len(result) == 2
@pytest.mark.asyncio
async def test_cognee_graph_mapping_batch_shapes():
"""Test that CogneeGraph mapping methods accept list-of-lists with query_list_length set."""
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
graph = CogneeGraph()
node1 = Node("node1", {"name": "Node1"})
node2 = Node("node2", {"name": "Node2"})
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(node1, node2, attributes={"edge_text": "relates_to"})
graph.add_edge(edge)
node_distances_batch = {
"Entity_name": [
[MockScoredResult("node1", 0.95)],
[MockScoredResult("node2", 0.87)],
]
}
edge_distances_batch = [
[MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
[MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
]
await graph.map_vector_distances_to_graph_nodes(
node_distances=node_distances_batch, query_list_length=2
)
await graph.map_vector_distances_to_graph_edges(
edge_distances=edge_distances_batch, query_list_length=2
)
assert node1.attributes.get("vector_distance") == [0.95, 3.5]
assert node2.attributes.get("vector_distance") == [3.5, 0.87]
assert edge.attributes.get("vector_distance") == [0.92, 0.88]

View file

@ -212,3 +212,29 @@ async def test_node_edge_vector_search_single_query_collection_not_found():
)
assert vector_search.node_distances["MissingCollection"] == []
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_batch_nodes_only():
"""Test has_results returns True when only node distances are populated in batch mode."""
mock_vector_engine = AsyncMock()
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
vector_search.query_list_length = 2
vector_search.edge_distances = [[], []]
vector_search.node_distances = {
"Entity_name": [[MockScoredResult("node1", 0.95)], []],
}
assert vector_search.has_results() is True
@pytest.mark.asyncio
async def test_node_edge_vector_search_has_results_batch_edges_only():
"""Test has_results returns True when only edge distances are populated in batch mode."""
mock_vector_engine = AsyncMock()
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
vector_search.query_list_length = 2
vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
vector_search.node_distances = {}
assert vector_search.has_results() is True