tests: update and expand test_brute_force_triplet_search.py and test_node_edge_vector_search.py
This commit is contained in:
parent
7833189001
commit
c20304a92a
2 changed files with 264 additions and 2 deletions
|
|
@ -30,7 +30,7 @@ async def test_brute_force_triplet_search_empty_query():
|
|||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_none_query():
|
||||
"""Test that None query raises ValueError."""
|
||||
with pytest.raises(ValueError, match="The query must be a non-empty string."):
|
||||
with pytest.raises(ValueError, match="Must provide either 'query' or 'query_batch'."):
|
||||
await brute_force_triplet_search(query=None)
|
||||
|
||||
|
||||
|
|
@ -351,7 +351,9 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
|
|||
custom_top_k = 15
|
||||
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
|
||||
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(
|
||||
k=custom_top_k, query_list_length=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -815,3 +817,237 @@ async def test_brute_force_triplet_search_collection_not_found_at_top_level():
|
|||
result = await brute_force_triplet_search(query="test query")
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_single_query_regression():
|
||||
"""Test that single-query mode maintains legacy behavior (flat list, ID filtering)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("node1", 0.95)])
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
result = await brute_force_triplet_search(
|
||||
query="q1", query_batch=None, wide_search_top_k=10, node_name=None
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert not (result and isinstance(result[0], list))
|
||||
mock_get_fragment.assert_called_once()
|
||||
call_kwargs = mock_get_fragment.call_args[1]
|
||||
assert call_kwargs["relevant_ids_to_filter"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_batch_wiring_happy_path():
|
||||
"""Test that batch mode returns list-of-lists and skips ID filtering."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.batch_search = AsyncMock(
|
||||
return_value=[
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
result = await brute_force_triplet_search(query_batch=["q1", "q2"])
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], list)
|
||||
assert isinstance(result[1], list)
|
||||
mock_get_fragment.assert_called_once()
|
||||
call_kwargs = mock_get_fragment.call_args[1]
|
||||
assert call_kwargs["relevant_ids_to_filter"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_shape_propagation_to_graph():
|
||||
"""Test that query_list_length is passed through to graph mapping methods."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.batch_search = AsyncMock(
|
||||
return_value=[
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
),
|
||||
):
|
||||
await brute_force_triplet_search(query_batch=["q1", "q2"])
|
||||
|
||||
mock_fragment.map_vector_distances_to_graph_nodes.assert_called_once()
|
||||
node_call_kwargs = mock_fragment.map_vector_distances_to_graph_nodes.call_args[1]
|
||||
assert "query_list_length" in node_call_kwargs
|
||||
assert node_call_kwargs["query_list_length"] == 2
|
||||
|
||||
mock_fragment.map_vector_distances_to_graph_edges.assert_called_once()
|
||||
edge_call_kwargs = mock_fragment.map_vector_distances_to_graph_edges.call_args[1]
|
||||
assert "query_list_length" in edge_call_kwargs
|
||||
assert edge_call_kwargs["query_list_length"] == 2
|
||||
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once()
|
||||
importance_call_kwargs = mock_fragment.calculate_top_triplet_importances.call_args[1]
|
||||
assert "query_list_length" in importance_call_kwargs
|
||||
assert importance_call_kwargs["query_list_length"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_batch_path_comprehensive():
|
||||
"""Test batch mode: returns list-of-lists, skips ID filtering, passes None for wide_search_limit."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
|
||||
def batch_search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return [
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
elif collection_name == "EdgeType_relationship_name":
|
||||
return [
|
||||
[MockScoredResult("edge1", 0.92)],
|
||||
[MockScoredResult("edge2", 0.88)],
|
||||
]
|
||||
return [[], []]
|
||||
|
||||
mock_vector_engine.batch_search = AsyncMock(side_effect=batch_search_side_effect)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[[], []]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
result = await brute_force_triplet_search(
|
||||
query_batch=["q1", "q2"], collections=["Entity_name", "EdgeType_relationship_name"]
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], list)
|
||||
assert isinstance(result[1], list)
|
||||
|
||||
mock_get_fragment.assert_called_once()
|
||||
fragment_call_kwargs = mock_get_fragment.call_args[1]
|
||||
assert fragment_call_kwargs["relevant_ids_to_filter"] is None
|
||||
|
||||
batch_search_calls = mock_vector_engine.batch_search.call_args_list
|
||||
assert len(batch_search_calls) > 0
|
||||
for call in batch_search_calls:
|
||||
assert call[1]["limit"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_batch_error_fallback():
|
||||
"""Test that CollectionNotFoundError in batch mode returns [[], []] matching batch length."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.batch_search = AsyncMock(
|
||||
side_effect=CollectionNotFoundError("Collection not found")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.node_edge_vector_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
result = await brute_force_triplet_search(query_batch=["q1", "q2"])
|
||||
|
||||
assert result == [[], []]
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cognee_graph_mapping_batch_shapes():
|
||||
"""Test that CogneeGraph mapping methods accept list-of-lists with query_list_length set."""
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
|
||||
|
||||
graph = CogneeGraph()
|
||||
node1 = Node("node1", {"name": "Node1"})
|
||||
node2 = Node("node2", {"name": "Node2"})
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(node1, node2, attributes={"edge_text": "relates_to"})
|
||||
graph.add_edge(edge)
|
||||
|
||||
node_distances_batch = {
|
||||
"Entity_name": [
|
||||
[MockScoredResult("node1", 0.95)],
|
||||
[MockScoredResult("node2", 0.87)],
|
||||
]
|
||||
}
|
||||
|
||||
edge_distances_batch = [
|
||||
[MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
|
||||
[MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(
|
||||
node_distances=node_distances_batch, query_list_length=2
|
||||
)
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
edge_distances=edge_distances_batch, query_list_length=2
|
||||
)
|
||||
|
||||
assert node1.attributes.get("vector_distance") == [0.95, 3.5]
|
||||
assert node2.attributes.get("vector_distance") == [3.5, 0.87]
|
||||
assert edge.attributes.get("vector_distance") == [0.92, 0.88]
|
||||
|
|
|
|||
|
|
@ -212,3 +212,29 @@ async def test_node_edge_vector_search_single_query_collection_not_found():
|
|||
)
|
||||
|
||||
assert vector_search.node_distances["MissingCollection"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_has_results_batch_nodes_only():
|
||||
"""Test has_results returns True when only node distances are populated in batch mode."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
vector_search.query_list_length = 2
|
||||
vector_search.edge_distances = [[], []]
|
||||
vector_search.node_distances = {
|
||||
"Entity_name": [[MockScoredResult("node1", 0.95)], []],
|
||||
}
|
||||
|
||||
assert vector_search.has_results() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_edge_vector_search_has_results_batch_edges_only():
|
||||
"""Test has_results returns True when only edge distances are populated in batch mode."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
vector_search = NodeEdgeVectorSearch(vector_engine=mock_vector_engine)
|
||||
vector_search.query_list_length = 2
|
||||
vector_search.edge_distances = [[MockScoredResult("edge1", 0.92)], []]
|
||||
vector_search.node_distances = {}
|
||||
|
||||
assert vector_search.has_results() is True
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue