fix: resolve python 3.10 issue with mocking [COG-1274] (#517)

<!-- .github/pull_request_template.md -->

## Description
Resolve issue with patch failing in python3.10

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Tests**
- Refined internal test setups to enhance clarity and streamline
dependency injection.
- Improved organization of test cases to ensure robust verification of
behaviors and error handling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Igor Ilic 2025-02-10 12:34:25 +01:00 committed by GitHub
parent 186b82c177
commit 591576b424
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,10 +1,12 @@
import pytest import pytest
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from cognee.tasks.storage.index_graph_edges import index_graph_edges
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_index_graph_edges_success(): async def test_index_graph_edges_success():
"""Test that index_graph_edges uses the index datapoints and creates vector index.""" """Test that index_graph_edges uses the index datapoints and creates vector index."""
# Create the mocks for the graph and vector engines.
mock_graph_engine = AsyncMock() mock_graph_engine = AsyncMock()
mock_graph_engine.get_graph_data.return_value = ( mock_graph_engine.get_graph_data.return_value = (
None, None,
@ -13,26 +15,25 @@ async def test_index_graph_edges_success():
[{"relationship_name": "rel2"}], [{"relationship_name": "rel2"}],
], ],
) )
mock_vector_engine = AsyncMock() mock_vector_engine = AsyncMock()
with ( # Patch the globals of the function so that when it does:
patch( # vector_engine = get_vector_engine()
"cognee.tasks.storage.index_graph_edges.get_graph_engine", # graph_engine = await get_graph_engine()
return_value=mock_graph_engine, # it uses the mocked versions.
), with patch.dict(
patch( index_graph_edges.__globals__,
"cognee.tasks.storage.index_graph_edges.get_vector_engine", {
return_value=mock_vector_engine, "get_graph_engine": AsyncMock(return_value=mock_graph_engine),
), "get_vector_engine": lambda: mock_vector_engine,
},
): ):
from cognee.tasks.storage.index_graph_edges import index_graph_edges
await index_graph_edges() await index_graph_edges()
mock_graph_engine.get_graph_data.assert_awaited_once() # Assertions on the mock calls.
assert mock_vector_engine.create_vector_index.await_count == 1 mock_graph_engine.get_graph_data.assert_awaited_once()
assert mock_vector_engine.index_data_points.await_count == 1 assert mock_vector_engine.create_vector_index.await_count == 1
assert mock_vector_engine.index_data_points.await_count == 1
@pytest.mark.asyncio @pytest.mark.asyncio
@ -40,39 +41,31 @@ async def test_index_graph_edges_no_relationships():
"""Test that index_graph_edges handles empty relationships correctly.""" """Test that index_graph_edges handles empty relationships correctly."""
mock_graph_engine = AsyncMock() mock_graph_engine = AsyncMock()
mock_graph_engine.get_graph_data.return_value = (None, []) mock_graph_engine.get_graph_data.return_value = (None, [])
mock_vector_engine = AsyncMock() mock_vector_engine = AsyncMock()
with ( with patch.dict(
patch( index_graph_edges.__globals__,
"cognee.tasks.storage.index_graph_edges.get_graph_engine", {
return_value=mock_graph_engine, "get_graph_engine": AsyncMock(return_value=mock_graph_engine),
), "get_vector_engine": lambda: mock_vector_engine,
patch( },
"cognee.tasks.storage.index_graph_edges.get_vector_engine",
return_value=mock_vector_engine,
),
): ):
from cognee.tasks.storage.index_graph_edges import index_graph_edges
await index_graph_edges() await index_graph_edges()
mock_graph_engine.get_graph_data.assert_awaited_once() mock_graph_engine.get_graph_data.assert_awaited_once()
mock_vector_engine.create_vector_index.assert_not_awaited() mock_vector_engine.create_vector_index.assert_not_awaited()
mock_vector_engine.index_data_points.assert_not_awaited() mock_vector_engine.index_data_points.assert_not_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_index_graph_edges_initialization_error(): async def test_index_graph_edges_initialization_error():
"""Test that index_graph_edges raises a RuntimeError if initialization fails.""" """Test that index_graph_edges raises a RuntimeError if initialization fails."""
with ( with patch.dict(
patch( index_graph_edges.__globals__,
"cognee.tasks.storage.index_graph_edges.get_graph_engine", {
side_effect=Exception("Graph engine failed"), "get_graph_engine": AsyncMock(side_effect=Exception("Graph engine failed")),
), "get_vector_engine": lambda: AsyncMock(),
patch("cognee.tasks.storage.index_graph_edges.get_vector_engine", return_value=AsyncMock()), },
): ):
from cognee.tasks.storage.index_graph_edges import index_graph_edges
with pytest.raises(RuntimeError, match="Initialization error"): with pytest.raises(RuntimeError, match="Initialization error"):
await index_graph_edges() await index_graph_edges()