adds 100% coverage to brute force triplet search
This commit is contained in:
parent
c2e402dbab
commit
22af6a050a
1 changed files with 218 additions and 9 deletions
|
|
@ -1,12 +1,14 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||
brute_force_triplet_search,
|
||||
get_memory_fragment,
|
||||
format_triplets,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||
|
||||
|
||||
class MockScoredResult:
|
||||
|
|
@ -354,20 +356,30 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
|
||||
"""Test that get_memory_fragment returns empty graph when entity not found."""
|
||||
"""Test that get_memory_fragment returns empty graph when entity not found (line 85)."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.project_graph_from_db = AsyncMock(
|
||||
|
||||
# Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called
|
||||
mock_fragment = MagicMock(spec=CogneeGraph)
|
||||
mock_fragment.project_graph_from_db = AsyncMock(
|
||||
side_effect=EntityNotFoundError("Entity not found")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph",
|
||||
return_value=mock_fragment,
|
||||
),
|
||||
):
|
||||
fragment = await get_memory_fragment()
|
||||
result = await get_memory_fragment()
|
||||
|
||||
assert isinstance(fragment, CogneeGraph)
|
||||
assert len(fragment.nodes) == 0
|
||||
# Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85)
|
||||
assert result == mock_fragment
|
||||
mock_fragment.project_graph_from_db.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -606,3 +618,200 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
|
|||
|
||||
call_kwargs = mock_get_fragment_fn.call_args[1]
|
||||
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
|
||||
|
||||
|
||||
def test_format_triplets():
|
||||
"""Test format_triplets function."""
|
||||
mock_edge = MagicMock()
|
||||
mock_node1 = MagicMock()
|
||||
mock_node2 = MagicMock()
|
||||
|
||||
mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"}
|
||||
mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"}
|
||||
mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"}
|
||||
|
||||
mock_edge.node1 = mock_node1
|
||||
mock_edge.node2 = mock_node2
|
||||
|
||||
result = format_triplets([mock_edge])
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "Node1" in result
|
||||
assert "Node2" in result
|
||||
assert "relates_to" in result
|
||||
assert "connects" in result
|
||||
|
||||
|
||||
def test_format_triplets_with_none_values():
|
||||
"""Test format_triplets filters out None values."""
|
||||
mock_edge = MagicMock()
|
||||
mock_node1 = MagicMock()
|
||||
mock_node2 = MagicMock()
|
||||
|
||||
mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"}
|
||||
mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None}
|
||||
mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None}
|
||||
|
||||
mock_edge.node1 = mock_node1
|
||||
mock_edge.node2 = mock_node2
|
||||
|
||||
result = format_triplets([mock_edge])
|
||||
|
||||
assert "Node1" in result
|
||||
assert "Node2" in result
|
||||
assert "relates_to" in result
|
||||
assert "None" not in result or result.count("None") == 0
|
||||
|
||||
|
||||
def test_format_triplets_with_nested_dict():
|
||||
"""Test format_triplets handles nested dict attributes (lines 23-35)."""
|
||||
mock_edge = MagicMock()
|
||||
mock_node1 = MagicMock()
|
||||
mock_node2 = MagicMock()
|
||||
|
||||
mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}}
|
||||
mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}}
|
||||
mock_edge.attributes = {"relationship_name": "relates_to"}
|
||||
|
||||
mock_edge.node1 = mock_node1
|
||||
mock_edge.node2 = mock_node2
|
||||
|
||||
result = format_triplets([mock_edge])
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "Node1" in result
|
||||
assert "Node2" in result
|
||||
assert "relates_to" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_vector_engine_init_error():
|
||||
"""Test brute_force_triplet_search handles vector engine initialization error (lines 145-147)."""
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine"
|
||||
) as mock_get_vector_engine,
|
||||
):
|
||||
mock_get_vector_engine.side_effect = Exception("Initialization error")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Initialization error"):
|
||||
await brute_force_triplet_search(query="test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_collection_not_found_error():
|
||||
"""Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = mock_embedding_engine
|
||||
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
|
||||
mock_vector_engine.search = AsyncMock(
|
||||
side_effect=[
|
||||
CollectionNotFoundError("Collection not found"),
|
||||
[],
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=CogneeGraph(),
|
||||
),
|
||||
):
|
||||
result = await brute_force_triplet_search(
|
||||
query="test query", collections=["missing_collection", "existing_collection"]
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_generic_exception():
|
||||
"""Test brute_force_triplet_search handles generic exceptions (lines 209-217)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = mock_embedding_engine
|
||||
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
|
||||
mock_vector_engine.search = AsyncMock(side_effect=Exception("Generic error"))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception, match="Generic error"):
|
||||
await brute_force_triplet_search(query="test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none():
|
||||
"""Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = mock_embedding_engine
|
||||
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
|
||||
mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
|
||||
mock_vector_engine.search = AsyncMock(return_value=[mock_result])
|
||||
|
||||
mock_fragment = AsyncMock()
|
||||
mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
|
||||
mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
|
||||
mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
await brute_force_triplet_search(query="test query", node_name=["Node1"])
|
||||
|
||||
assert mock_get_fragment.called
|
||||
call_kwargs = mock_get_fragment.call_args.kwargs if mock_get_fragment.call_args else {}
|
||||
assert call_kwargs.get("relevant_ids_to_filter") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_collection_not_found_at_top_level():
|
||||
"""Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = mock_embedding_engine
|
||||
mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
|
||||
mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
|
||||
mock_vector_engine.search = AsyncMock(return_value=[mock_result])
|
||||
|
||||
mock_fragment = AsyncMock()
|
||||
mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
|
||||
mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
|
||||
mock_fragment.calculate_top_triplet_importances = AsyncMock(
|
||||
side_effect=CollectionNotFoundError("Collection not found")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
),
|
||||
):
|
||||
result = await brute_force_triplet_search(query="test query")
|
||||
|
||||
assert result == []
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue