Moving bfs depth unit tests to search utils tests file.
This commit is contained in:
parent
ecf9a75fbf
commit
f99ec00a0c
2 changed files with 80 additions and 95 deletions
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
|
|
||||||
from graphiti_core.nodes import EntityNode
|
from graphiti_core.nodes import EntityNode
|
||||||
from graphiti_core.search.search_filters import SearchFilters
|
from graphiti_core.search.search_filters import SearchFilters
|
||||||
from graphiti_core.search.search_utils import hybrid_node_search
|
from graphiti_core.search.search_utils import edge_bfs_search, hybrid_node_search, node_bfs_search
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -161,3 +161,82 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
|
||||||
mock_similarity_search.assert_called_with(
|
mock_similarity_search.assert_called_with(
|
||||||
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4
|
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_edge_bfs_search_uses_depth_parameter():
|
||||||
|
"""Test that edge_bfs_search uses the bfs_max_depth parameter in the query."""
|
||||||
|
# Mock driver
|
||||||
|
mock_driver = AsyncMock()
|
||||||
|
mock_driver.execute_query.return_value = ([], None, None)
|
||||||
|
|
||||||
|
# Mock search filter
|
||||||
|
search_filter = SearchFilters()
|
||||||
|
|
||||||
|
# Call edge_bfs_search with depth=2
|
||||||
|
await edge_bfs_search(
|
||||||
|
driver=mock_driver,
|
||||||
|
bfs_origin_node_uuids=['test-uuid'],
|
||||||
|
bfs_max_depth=2,
|
||||||
|
search_filter=search_filter,
|
||||||
|
group_ids=['test-group'],
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the query was called
|
||||||
|
assert mock_driver.execute_query.called
|
||||||
|
call_args = mock_driver.execute_query.call_args
|
||||||
|
|
||||||
|
# Check that the query contains the variable depth pattern
|
||||||
|
query = call_args.args[0]
|
||||||
|
assert '*1..2' in query, f"Query should contain '*1..2' but got: {query}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_node_bfs_search_uses_depth_parameter():
|
||||||
|
"""Test that node_bfs_search uses the bfs_max_depth parameter in the query."""
|
||||||
|
# Mock driver
|
||||||
|
mock_driver = AsyncMock()
|
||||||
|
mock_driver.execute_query.return_value = ([], None, None)
|
||||||
|
|
||||||
|
# Mock search filter
|
||||||
|
search_filter = SearchFilters()
|
||||||
|
|
||||||
|
# Call node_bfs_search with depth=1
|
||||||
|
await node_bfs_search(
|
||||||
|
driver=mock_driver,
|
||||||
|
bfs_origin_node_uuids=['test-uuid'],
|
||||||
|
search_filter=search_filter,
|
||||||
|
bfs_max_depth=1,
|
||||||
|
group_ids=['test-group'],
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the query was called
|
||||||
|
assert mock_driver.execute_query.called
|
||||||
|
call_args = mock_driver.execute_query.call_args
|
||||||
|
|
||||||
|
# Check that the query contains the variable depth pattern
|
||||||
|
query = call_args.args[0]
|
||||||
|
assert '*1..1' in query, f"Query should contain '*1..1' but got: {query}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_different_depth_values():
|
||||||
|
"""Test that different bfs_max_depth values are correctly passed."""
|
||||||
|
mock_driver = AsyncMock()
|
||||||
|
mock_driver.execute_query.return_value = ([], None, None)
|
||||||
|
search_filter = SearchFilters()
|
||||||
|
|
||||||
|
# Test depth=5
|
||||||
|
await edge_bfs_search(
|
||||||
|
driver=mock_driver,
|
||||||
|
bfs_origin_node_uuids=['test-uuid'],
|
||||||
|
bfs_max_depth=5,
|
||||||
|
search_filter=search_filter,
|
||||||
|
group_ids=['test-group'],
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = mock_driver.execute_query.call_args
|
||||||
|
query = call_args.args[0]
|
||||||
|
assert '*1..5' in query, f"Query should contain '*1..5' but got: {query}"
|
||||||
|
|
|
||||||
|
|
@ -1,94 +0,0 @@
|
||||||
"""Test for BFS max_depth parameter fix."""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from graphiti_core.search.search_filters import SearchFilters
|
|
||||||
from graphiti_core.search.search_utils import edge_bfs_search, node_bfs_search
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_edge_bfs_search_uses_depth_parameter():
|
|
||||||
"""Test that edge_bfs_search uses the bfs_max_depth parameter in the query."""
|
|
||||||
# Mock driver
|
|
||||||
mock_driver = AsyncMock()
|
|
||||||
mock_driver.execute_query.return_value = ([], None, None)
|
|
||||||
|
|
||||||
# Mock search filter
|
|
||||||
search_filter = SearchFilters()
|
|
||||||
|
|
||||||
# Call edge_bfs_search with depth=2
|
|
||||||
await edge_bfs_search(
|
|
||||||
driver=mock_driver,
|
|
||||||
bfs_origin_node_uuids=['test-uuid'],
|
|
||||||
bfs_max_depth=2,
|
|
||||||
search_filter=search_filter,
|
|
||||||
group_ids=['test-group'],
|
|
||||||
limit=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the query was called
|
|
||||||
assert mock_driver.execute_query.called
|
|
||||||
call_args = mock_driver.execute_query.call_args
|
|
||||||
|
|
||||||
# Check that depth parameter is passed
|
|
||||||
assert 'depth' in call_args.kwargs
|
|
||||||
assert call_args.kwargs['depth'] == 2
|
|
||||||
|
|
||||||
# Check that the query contains the variable depth pattern
|
|
||||||
query = call_args.args[0]
|
|
||||||
assert '{1,$depth}' in query, f"Query should contain '{{1,$depth}}' but got: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_node_bfs_search_uses_depth_parameter():
|
|
||||||
"""Test that node_bfs_search uses the bfs_max_depth parameter in the query."""
|
|
||||||
# Mock driver
|
|
||||||
mock_driver = AsyncMock()
|
|
||||||
mock_driver.execute_query.return_value = ([], None, None)
|
|
||||||
|
|
||||||
# Mock search filter
|
|
||||||
search_filter = SearchFilters()
|
|
||||||
|
|
||||||
# Call node_bfs_search with depth=1
|
|
||||||
await node_bfs_search(
|
|
||||||
driver=mock_driver,
|
|
||||||
bfs_origin_node_uuids=['test-uuid'],
|
|
||||||
search_filter=search_filter,
|
|
||||||
bfs_max_depth=1,
|
|
||||||
group_ids=['test-group'],
|
|
||||||
limit=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the query was called
|
|
||||||
assert mock_driver.execute_query.called
|
|
||||||
call_args = mock_driver.execute_query.call_args
|
|
||||||
|
|
||||||
# Check that depth parameter is passed
|
|
||||||
assert 'depth' in call_args.kwargs
|
|
||||||
assert call_args.kwargs['depth'] == 1
|
|
||||||
|
|
||||||
# Check that the query contains the variable depth pattern
|
|
||||||
query = call_args.args[0]
|
|
||||||
assert '{1,$depth}' in query, f"Query should contain '{{1,$depth}}' but got: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_different_depth_values():
|
|
||||||
"""Test that different bfs_max_depth values are correctly passed."""
|
|
||||||
mock_driver = AsyncMock()
|
|
||||||
mock_driver.execute_query.return_value = ([], None, None)
|
|
||||||
search_filter = SearchFilters()
|
|
||||||
|
|
||||||
# Test depth=5
|
|
||||||
await edge_bfs_search(
|
|
||||||
driver=mock_driver,
|
|
||||||
bfs_origin_node_uuids=['test-uuid'],
|
|
||||||
bfs_max_depth=5,
|
|
||||||
search_filter=search_filter,
|
|
||||||
group_ids=['test-group'],
|
|
||||||
)
|
|
||||||
|
|
||||||
call_args = mock_driver.execute_query.call_args
|
|
||||||
assert call_args.kwargs['depth'] == 5
|
|
||||||
Loading…
Add table
Reference in a new issue