fix: resolve python 3.10 issue with mocking [COG-1274] (#517)
<!-- .github/pull_request_template.md --> ## Description Resolve issue with patch failing in python3.10 ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Tests** - Refined internal test setups to enhance clarity and streamline dependency injection. - Improved organization of test cases to ensure robust verification of behaviors and error handling. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
186b82c177
commit
591576b424
1 changed files with 31 additions and 38 deletions
|
|
@ -1,10 +1,12 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_graph_edges_success():
|
||||
"""Test that index_graph_edges uses the index datapoints and creates vector index."""
|
||||
# Create the mocks for the graph and vector engines.
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.get_graph_data.return_value = (
|
||||
None,
|
||||
|
|
@ -13,26 +15,25 @@ async def test_index_graph_edges_success():
|
|||
[{"relationship_name": "rel2"}],
|
||||
],
|
||||
)
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.tasks.storage.index_graph_edges.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.tasks.storage.index_graph_edges.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
# Patch the globals of the function so that when it does:
|
||||
# vector_engine = get_vector_engine()
|
||||
# graph_engine = await get_graph_engine()
|
||||
# it uses the mocked versions.
|
||||
with patch.dict(
|
||||
index_graph_edges.__globals__,
|
||||
{
|
||||
"get_graph_engine": AsyncMock(return_value=mock_graph_engine),
|
||||
"get_vector_engine": lambda: mock_vector_engine,
|
||||
},
|
||||
):
|
||||
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
||||
|
||||
await index_graph_edges()
|
||||
|
||||
mock_graph_engine.get_graph_data.assert_awaited_once()
|
||||
assert mock_vector_engine.create_vector_index.await_count == 1
|
||||
assert mock_vector_engine.index_data_points.await_count == 1
|
||||
# Assertions on the mock calls.
|
||||
mock_graph_engine.get_graph_data.assert_awaited_once()
|
||||
assert mock_vector_engine.create_vector_index.await_count == 1
|
||||
assert mock_vector_engine.index_data_points.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -40,39 +41,31 @@ async def test_index_graph_edges_no_relationships():
|
|||
"""Test that index_graph_edges handles empty relationships correctly."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.get_graph_data.return_value = (None, [])
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.tasks.storage.index_graph_edges.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.tasks.storage.index_graph_edges.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
with patch.dict(
|
||||
index_graph_edges.__globals__,
|
||||
{
|
||||
"get_graph_engine": AsyncMock(return_value=mock_graph_engine),
|
||||
"get_vector_engine": lambda: mock_vector_engine,
|
||||
},
|
||||
):
|
||||
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
||||
|
||||
await index_graph_edges()
|
||||
|
||||
mock_graph_engine.get_graph_data.assert_awaited_once()
|
||||
mock_vector_engine.create_vector_index.assert_not_awaited()
|
||||
mock_vector_engine.index_data_points.assert_not_awaited()
|
||||
mock_graph_engine.get_graph_data.assert_awaited_once()
|
||||
mock_vector_engine.create_vector_index.assert_not_awaited()
|
||||
mock_vector_engine.index_data_points.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_graph_edges_initialization_error():
|
||||
"""Test that index_graph_edges raises a RuntimeError if initialization fails."""
|
||||
with (
|
||||
patch(
|
||||
"cognee.tasks.storage.index_graph_edges.get_graph_engine",
|
||||
side_effect=Exception("Graph engine failed"),
|
||||
),
|
||||
patch("cognee.tasks.storage.index_graph_edges.get_vector_engine", return_value=AsyncMock()),
|
||||
with patch.dict(
|
||||
index_graph_edges.__globals__,
|
||||
{
|
||||
"get_graph_engine": AsyncMock(side_effect=Exception("Graph engine failed")),
|
||||
"get_vector_engine": lambda: AsyncMock(),
|
||||
},
|
||||
):
|
||||
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
||||
|
||||
with pytest.raises(RuntimeError, match="Initialization error"):
|
||||
await index_graph_edges()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue