diff --git a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py index 59ad606b0..9cd96f5b9 100644 --- a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +++ b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py @@ -1,10 +1,12 @@ import pytest from unittest.mock import AsyncMock, patch +from cognee.tasks.storage.index_graph_edges import index_graph_edges @pytest.mark.asyncio async def test_index_graph_edges_success(): """Test that index_graph_edges uses the index datapoints and creates vector index.""" + # Create the mocks for the graph and vector engines. mock_graph_engine = AsyncMock() mock_graph_engine.get_graph_data.return_value = ( None, @@ -13,26 +15,25 @@ async def test_index_graph_edges_success(): [{"relationship_name": "rel2"}], ], ) - mock_vector_engine = AsyncMock() - with ( - patch( - "cognee.tasks.storage.index_graph_edges.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.tasks.storage.index_graph_edges.get_vector_engine", - return_value=mock_vector_engine, - ), + # Patch the globals of the function so that when it does: + # vector_engine = get_vector_engine() + # graph_engine = await get_graph_engine() + # it uses the mocked versions. + with patch.dict( + index_graph_edges.__globals__, + { + "get_graph_engine": AsyncMock(return_value=mock_graph_engine), + "get_vector_engine": lambda: mock_vector_engine, + }, ): - from cognee.tasks.storage.index_graph_edges import index_graph_edges - await index_graph_edges() - mock_graph_engine.get_graph_data.assert_awaited_once() - assert mock_vector_engine.create_vector_index.await_count == 1 - assert mock_vector_engine.index_data_points.await_count == 1 + # Assertions on the mock calls. + mock_graph_engine.get_graph_data.assert_awaited_once() + assert mock_vector_engine.create_vector_index.await_count == 1 + assert mock_vector_engine.index_data_points.await_count == 1 @pytest.mark.asyncio @@ -40,39 +41,31 @@ async def test_index_graph_edges_no_relationships(): """Test that index_graph_edges handles empty relationships correctly.""" mock_graph_engine = AsyncMock() mock_graph_engine.get_graph_data.return_value = (None, []) - mock_vector_engine = AsyncMock() - with ( - patch( - "cognee.tasks.storage.index_graph_edges.get_graph_engine", - return_value=mock_graph_engine, - ), - patch( - "cognee.tasks.storage.index_graph_edges.get_vector_engine", - return_value=mock_vector_engine, - ), + with patch.dict( + index_graph_edges.__globals__, + { + "get_graph_engine": AsyncMock(return_value=mock_graph_engine), + "get_vector_engine": lambda: mock_vector_engine, + }, ): - from cognee.tasks.storage.index_graph_edges import index_graph_edges - await index_graph_edges() - mock_graph_engine.get_graph_data.assert_awaited_once() - mock_vector_engine.create_vector_index.assert_not_awaited() - mock_vector_engine.index_data_points.assert_not_awaited() + mock_graph_engine.get_graph_data.assert_awaited_once() + mock_vector_engine.create_vector_index.assert_not_awaited() + mock_vector_engine.index_data_points.assert_not_awaited() @pytest.mark.asyncio async def test_index_graph_edges_initialization_error(): """Test that index_graph_edges raises a RuntimeError if initialization fails.""" - with ( - patch( - "cognee.tasks.storage.index_graph_edges.get_graph_engine", - side_effect=Exception("Graph engine failed"), - ), - patch("cognee.tasks.storage.index_graph_edges.get_vector_engine", return_value=AsyncMock()), + with patch.dict( + index_graph_edges.__globals__, + { + "get_graph_engine": AsyncMock(side_effect=Exception("Graph engine failed")), + "get_vector_engine": lambda: AsyncMock(), + }, ): - from cognee.tasks.storage.index_graph_edges import index_graph_edges - with pytest.raises(RuntimeError, match="Initialization error"): await index_graph_edges()