fix: BFS max_depth parameter now properly controls traversal depth
- Fixed hardcoded {1,3} path lengths in both edge_bfs_search and node_bfs_search
- Replaced with {1,} to use the bfs_max_depth parameter correctly
- Added tests to verify the fix works
- Resolves issue where bfs_max_depth=1 would still traverse 3 hops
Closes #772
This commit is contained in:
parent
9ceeb54186
commit
2d085f61bb
2 changed files with 96 additions and 2 deletions
|
|
@ -295,7 +295,7 @@ async def edge_bfs_search(
|
|||
query = (
|
||||
"""
|
||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,$depth}(n:Entity)
|
||||
UNWIND relationships(path) AS rel
|
||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
||||
WHERE r.uuid = rel.uuid
|
||||
|
|
@ -446,7 +446,7 @@ async def node_bfs_search(
|
|||
query = (
|
||||
"""
|
||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,$depth}(n:Entity)
|
||||
WHERE n.group_id = origin.group_id
|
||||
AND origin.group_id IN $group_ids
|
||||
"""
|
||||
|
|
|
|||
94
tests/utils/search/test_bfs_depth_fix.py
Normal file
94
tests/utils/search/test_bfs_depth_fix.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
"""Test for BFS max_depth parameter fix."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.search.search_filters import SearchFilters
|
||||
from graphiti_core.search.search_utils import edge_bfs_search, node_bfs_search
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edge_bfs_search_uses_depth_parameter():
|
||||
"""Test that edge_bfs_search uses the bfs_max_depth parameter in the query."""
|
||||
# Mock driver
|
||||
mock_driver = AsyncMock()
|
||||
mock_driver.execute_query.return_value = ([], None, None)
|
||||
|
||||
# Mock search filter
|
||||
search_filter = SearchFilters()
|
||||
|
||||
# Call edge_bfs_search with depth=2
|
||||
await edge_bfs_search(
|
||||
driver=mock_driver,
|
||||
bfs_origin_node_uuids=['test-uuid'],
|
||||
bfs_max_depth=2,
|
||||
search_filter=search_filter,
|
||||
group_ids=['test-group'],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_driver.execute_query.called
|
||||
call_args = mock_driver.execute_query.call_args
|
||||
|
||||
# Check that depth parameter is passed
|
||||
assert 'depth' in call_args.kwargs
|
||||
assert call_args.kwargs['depth'] == 2
|
||||
|
||||
# Check that the query contains the variable depth pattern
|
||||
query = call_args.args[0]
|
||||
assert '{1,$depth}' in query, f"Query should contain '{{1,$depth}}' but got: {query}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_bfs_search_uses_depth_parameter():
|
||||
"""Test that node_bfs_search uses the bfs_max_depth parameter in the query."""
|
||||
# Mock driver
|
||||
mock_driver = AsyncMock()
|
||||
mock_driver.execute_query.return_value = ([], None, None)
|
||||
|
||||
# Mock search filter
|
||||
search_filter = SearchFilters()
|
||||
|
||||
# Call node_bfs_search with depth=1
|
||||
await node_bfs_search(
|
||||
driver=mock_driver,
|
||||
bfs_origin_node_uuids=['test-uuid'],
|
||||
search_filter=search_filter,
|
||||
bfs_max_depth=1,
|
||||
group_ids=['test-group'],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the query was called
|
||||
assert mock_driver.execute_query.called
|
||||
call_args = mock_driver.execute_query.call_args
|
||||
|
||||
# Check that depth parameter is passed
|
||||
assert 'depth' in call_args.kwargs
|
||||
assert call_args.kwargs['depth'] == 1
|
||||
|
||||
# Check that the query contains the variable depth pattern
|
||||
query = call_args.args[0]
|
||||
assert '{1,$depth}' in query, f"Query should contain '{{1,$depth}}' but got: {query}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_depth_values():
|
||||
"""Test that different bfs_max_depth values are correctly passed."""
|
||||
mock_driver = AsyncMock()
|
||||
mock_driver.execute_query.return_value = ([], None, None)
|
||||
search_filter = SearchFilters()
|
||||
|
||||
# Test depth=5
|
||||
await edge_bfs_search(
|
||||
driver=mock_driver,
|
||||
bfs_origin_node_uuids=['test-uuid'],
|
||||
bfs_max_depth=5,
|
||||
search_filter=search_filter,
|
||||
group_ids=['test-group'],
|
||||
)
|
||||
|
||||
call_args = mock_driver.execute_query.call_args
|
||||
assert call_args.kwargs['depth'] == 5
|
||||
Loading…
Add table
Reference in a new issue