diff --git a/tests/utils/search/search_utils_test.py b/tests/utils/search/search_utils_test.py index 6b97daab..47f91c6c 100644 --- a/tests/utils/search/search_utils_test.py +++ b/tests/utils/search/search_utils_test.py @@ -4,7 +4,7 @@ import pytest from graphiti_core.nodes import EntityNode from graphiti_core.search.search_filters import SearchFilters -from graphiti_core.search.search_utils import hybrid_node_search +from graphiti_core.search.search_utils import edge_bfs_search, hybrid_node_search, node_bfs_search @pytest.mark.asyncio @@ -161,3 +161,82 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): mock_similarity_search.assert_called_with( mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4 ) + + +@pytest.mark.asyncio +async def test_edge_bfs_search_uses_depth_parameter(): + """Test that edge_bfs_search uses the bfs_max_depth parameter in the query.""" + # Mock driver + mock_driver = AsyncMock() + mock_driver.execute_query.return_value = ([], None, None) + + # Mock search filter + search_filter = SearchFilters() + + # Call edge_bfs_search with depth=2 + await edge_bfs_search( + driver=mock_driver, + bfs_origin_node_uuids=['test-uuid'], + bfs_max_depth=2, + search_filter=search_filter, + group_ids=['test-group'], + limit=10, + ) + + # Verify the query was called + assert mock_driver.execute_query.called + call_args = mock_driver.execute_query.call_args + + # Check that the query contains the variable depth pattern + query = call_args.args[0] + assert '*1..2' in query, f"Query should contain '*1..2' but got: {query}" + + +@pytest.mark.asyncio +async def test_node_bfs_search_uses_depth_parameter(): + """Test that node_bfs_search uses the bfs_max_depth parameter in the query.""" + # Mock driver + mock_driver = AsyncMock() + mock_driver.execute_query.return_value = ([], None, None) + + # Mock search filter + search_filter = SearchFilters() + + # Call node_bfs_search with depth=1 + await node_bfs_search( + driver=mock_driver, + bfs_origin_node_uuids=['test-uuid'], + search_filter=search_filter, + bfs_max_depth=1, + group_ids=['test-group'], + limit=10, + ) + + # Verify the query was called + assert mock_driver.execute_query.called + call_args = mock_driver.execute_query.call_args + + # Check that the query contains the variable depth pattern + query = call_args.args[0] + assert '*1..1' in query, f"Query should contain '*1..1' but got: {query}" + + +@pytest.mark.asyncio +async def test_different_depth_values(): + """Test that different bfs_max_depth values are correctly passed.""" + mock_driver = AsyncMock() + mock_driver.execute_query.return_value = ([], None, None) + search_filter = SearchFilters() + + # Test depth=5 + await edge_bfs_search( + driver=mock_driver, + bfs_origin_node_uuids=['test-uuid'], + bfs_max_depth=5, + search_filter=search_filter, + group_ids=['test-group'], + ) + + call_args = mock_driver.execute_query.call_args + query = call_args.args[0] + assert '*1..5' in query, f"Query should contain '*1..5' but got: {query}"