This commit is contained in:
Pavel Jakovlev 2025-11-27 02:19:13 +01:00 committed by GitHub
commit b7c55c48e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,7 +4,7 @@ import pytest
from graphiti_core.nodes import EntityNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import hybrid_node_search
from graphiti_core.search.search_utils import edge_bfs_search, hybrid_node_search, node_bfs_search
@pytest.mark.asyncio
@ -161,3 +161,82 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
mock_similarity_search.assert_called_with(
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4
)
@pytest.mark.asyncio
async def test_edge_bfs_search_uses_depth_parameter():
"""Test that edge_bfs_search uses the bfs_max_depth parameter in the query."""
# Mock driver
mock_driver = AsyncMock()
mock_driver.execute_query.return_value = ([], None, None)
# Mock search filter
search_filter = SearchFilters()
# Call edge_bfs_search with depth=2
await edge_bfs_search(
driver=mock_driver,
bfs_origin_node_uuids=['test-uuid'],
bfs_max_depth=2,
search_filter=search_filter,
group_ids=['test-group'],
limit=10,
)
# Verify the query was called
assert mock_driver.execute_query.called
call_args = mock_driver.execute_query.call_args
# Check that the query contains the variable depth pattern
query = call_args.args[0]
assert '*1..2' in query, f"Query should contain '*1..2' but got: {query}"
@pytest.mark.asyncio
async def test_node_bfs_search_uses_depth_parameter():
"""Test that node_bfs_search uses the bfs_max_depth parameter in the query."""
# Mock driver
mock_driver = AsyncMock()
mock_driver.execute_query.return_value = ([], None, None)
# Mock search filter
search_filter = SearchFilters()
# Call node_bfs_search with depth=1
await node_bfs_search(
driver=mock_driver,
bfs_origin_node_uuids=['test-uuid'],
search_filter=search_filter,
bfs_max_depth=1,
group_ids=['test-group'],
limit=10,
)
# Verify the query was called
assert mock_driver.execute_query.called
call_args = mock_driver.execute_query.call_args
# Check that the query contains the variable depth pattern
query = call_args.args[0]
assert '*1..1' in query, f"Query should contain '*1..1' but got: {query}"
@pytest.mark.asyncio
async def test_different_depth_values():
"""Test that different bfs_max_depth values are correctly passed."""
mock_driver = AsyncMock()
mock_driver.execute_query.return_value = ([], None, None)
search_filter = SearchFilters()
# Test depth=5
await edge_bfs_search(
driver=mock_driver,
bfs_origin_node_uuids=['test-uuid'],
bfs_max_depth=5,
search_filter=search_filter,
group_ids=['test-group'],
)
call_args = mock_driver.execute_query.call_args
query = call_args.args[0]
assert '*1..5' in query, f"Query should contain '*1..5' but got: {query}"