adds 100% coverage to brute force triplet search

This commit is contained in:
hajdul88 2025-12-11 12:56:42 +01:00
parent c2e402dbab
commit 22af6a050a

View file

@ -1,12 +1,14 @@
import pytest import pytest
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.utils.brute_force_triplet_search import ( from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_triplet_search, brute_force_triplet_search,
get_memory_fragment, get_memory_fragment,
format_triplets,
) )
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
class MockScoredResult: class MockScoredResult:
@ -354,20 +356,30 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
    """Check that get_memory_fragment still returns its graph when projection fails.

    The CogneeGraph constructor is patched to hand back a mock whose
    ``project_graph_from_db`` raises EntityNotFoundError; the call is expected to
    return that same mock rather than propagate the error.
    """
    mock_graph_engine = AsyncMock()

    # Fragment whose projection step fails with EntityNotFoundError when awaited.
    mock_fragment = MagicMock(spec=CogneeGraph)
    mock_fragment.project_graph_from_db = AsyncMock(
        side_effect=EntityNotFoundError("Entity not found")
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
            return_value=mock_graph_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph",
            return_value=mock_fragment,
        ),
    ):
        result = await get_memory_fragment()

    # The very same fragment object must come back even though projection raised.
    assert result is mock_fragment
    mock_fragment.project_graph_from_db.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -606,3 +618,200 @@ async def test_brute_force_triplet_search_mixed_empty_collections():
call_kwargs = mock_get_fragment_fn.call_args[1] call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
def test_format_triplets():
    """Check that format_triplets renders node and edge attribute values into a string."""
    source_node = MagicMock(attributes={"name": "Node1", "type": "Entity", "id": "n1"})
    target_node = MagicMock(attributes={"name": "Node2", "type": "Entity", "id": "n2"})
    edge = MagicMock(
        node1=source_node,
        node2=target_node,
        attributes={"relationship_name": "relates_to", "edge_text": "connects"},
    )

    rendered = format_triplets([edge])

    assert isinstance(rendered, str)
    # Every attribute value supplied above should be visible in the output text.
    for expected in ("Node1", "Node2", "relates_to", "connects"):
        assert expected in rendered
def test_format_triplets_with_none_values():
    """Check that None-valued attributes are left out of format_triplets output."""
    mock_edge = MagicMock()
    mock_node1 = MagicMock()
    mock_node2 = MagicMock()

    # Each of the three attribute dicts carries one None value that must be dropped.
    mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"}
    mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None}
    mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None}

    mock_edge.node1 = mock_node1
    mock_edge.node2 = mock_node2

    result = format_triplets([mock_edge])

    assert "Node1" in result
    assert "Node2" in result
    assert "relates_to" in result
    # None values must not leak into the text as the literal string "None".
    assert "None" not in result
def test_format_triplets_with_nested_dict():
    """Check that format_triplets accepts attribute values that are themselves dicts."""
    first = MagicMock(attributes={"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}})
    second = MagicMock(attributes={"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}})
    edge = MagicMock(
        node1=first,
        node2=second,
        attributes={"relationship_name": "relates_to"},
    )

    rendered = format_triplets([edge])

    assert isinstance(rendered, str)
    # Only the top-level values are checked; the nested content is not asserted on.
    for expected in ("Node1", "Node2", "relates_to"):
        assert expected in rendered
@pytest.mark.asyncio
async def test_brute_force_triplet_search_vector_engine_init_error():
    """Check that a failure while obtaining the vector engine surfaces as RuntimeError.

    ``get_vector_engine`` is patched to raise a plain Exception; the search is
    expected to raise RuntimeError whose message contains the original text.
    """
    failing_factory = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        side_effect=Exception("Initialization error"),
    )

    with failing_factory, pytest.raises(RuntimeError, match="Initialization error"):
        await brute_force_triplet_search(query="test query")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_error():
    """Check that CollectionNotFoundError from one vector search does not abort the call.

    The first queued search outcome raises CollectionNotFoundError and the
    remaining ones are empty lists; the overall result is expected to be ``[]``.
    """
    embedder = AsyncMock()
    embedder.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    engine = AsyncMock()
    engine.embedding_engine = embedder
    # Outcomes are consumed one per search() call. Three are queued although only two
    # collections are requested — presumably the function searches one extra
    # collection of its own; TODO confirm against the implementation.
    search_outcomes = [CollectionNotFoundError("Collection not found"), [], []]
    engine.search = AsyncMock(side_effect=search_outcomes)

    patched_engine = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=engine,
    )
    patched_fragment = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=CogneeGraph(),
    )

    with patched_engine, patched_fragment:
        found = await brute_force_triplet_search(
            query="test query", collections=["missing_collection", "existing_collection"]
        )

    assert found == []
@pytest.mark.asyncio
async def test_brute_force_triplet_search_generic_exception():
    """Check that an unexpected error from the vector search reaches the caller.

    Every ``search`` call raises a plain Exception; the call under test is expected
    to raise an exception whose message contains the same text.
    """
    embedder = AsyncMock()
    embedder.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    engine = AsyncMock()
    engine.embedding_engine = embedder
    engine.search = AsyncMock(side_effect=Exception("Generic error"))

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=engine,
        ),
        pytest.raises(Exception, match="Generic error"),
    ):
        await brute_force_triplet_search(query="test query")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none():
    """Check that supplying node_name makes the search pass relevant_ids_to_filter=None.

    The vector search returns one scored hit, so an ID would be available for
    filtering; the test verifies that get_memory_fragment is nevertheless called
    with the ``relevant_ids_to_filter`` keyword explicitly set to None.
    """
    mock_vector_engine = AsyncMock()
    mock_embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine = mock_embedding_engine
    mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
    mock_vector_engine.search = AsyncMock(return_value=[mock_result])

    mock_fragment = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock()
    mock_fragment.map_vector_distances_to_graph_edges = AsyncMock()
    mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[])

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        await brute_force_triplet_search(query="test query", node_name=["Node1"])

    assert mock_get_fragment.called
    call_kwargs = mock_get_fragment.call_args.kwargs
    # Require the keyword to be present: a bare .get() would also pass if the
    # argument had been omitted entirely, which would hide a regression.
    assert "relevant_ids_to_filter" in call_kwargs
    assert call_kwargs["relevant_ids_to_filter"] is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_collection_not_found_at_top_level():
    """Check that CollectionNotFoundError raised after the vector search yields ``[]``.

    The vector search succeeds with one hit, but the fragment's
    ``calculate_top_triplet_importances`` raises CollectionNotFoundError; the call
    under test is expected to return an empty list instead of raising.
    """
    embedder = AsyncMock()
    embedder.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])

    engine = AsyncMock()
    engine.embedding_engine = embedder
    hit = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"})
    engine.search = AsyncMock(return_value=[hit])

    fragment = AsyncMock()
    fragment.map_vector_distances_to_graph_nodes = AsyncMock()
    fragment.map_vector_distances_to_graph_edges = AsyncMock()
    # The failure is injected at the ranking step, after both mapping steps.
    fragment.calculate_top_triplet_importances = AsyncMock(
        side_effect=CollectionNotFoundError("Collection not found")
    )

    patched_engine = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=engine,
    )
    patched_fragment = patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
        return_value=fragment,
    )

    with patched_engine, patched_fragment:
        found = await brute_force_triplet_search(query="test query")

    assert found == []