diff --git a/cognee/modules/retrieval/exceptions/__init__.py b/cognee/modules/retrieval/exceptions/__init__.py
index 1b98cddcd..feb84a962 100644
--- a/cognee/modules/retrieval/exceptions/__init__.py
+++ b/cognee/modules/retrieval/exceptions/__init__.py
@@ -4,4 +4,4 @@ Custom exceptions for the Cognee API.
 This module defines a set of exceptions for handling various data errors
 """
 
-from .exceptions import SearchTypeNotSupported, CypherSearchError
+from .exceptions import SearchTypeNotSupported, CypherSearchError, CollectionDistancesNotFoundError
diff --git a/cognee/modules/retrieval/exceptions/exceptions.py b/cognee/modules/retrieval/exceptions/exceptions.py
index 1b7c34251..7e33e3a5f 100644
--- a/cognee/modules/retrieval/exceptions/exceptions.py
+++ b/cognee/modules/retrieval/exceptions/exceptions.py
@@ -2,6 +2,21 @@ from cognee.exceptions import CogneeApiError
 from fastapi import status
 
 
+class CollectionDistancesNotFoundError(CogneeApiError):
+    """Signal that a query produced no distances for any searched collection.
+
+    Defaults to an HTTP 404 status and a message hinting at missing collections.
+    """
+
+    def __init__(
+        self,
+        message: str = "No distances found between the query and collections. It is possible that the given collection names don't exist.",
+        name: str = "CollectionDistancesNotFoundError",
+        status_code: int = status.HTTP_404_NOT_FOUND,
+    ) -> None:
+        super().__init__(message, name, status_code)
+
+
 class SearchTypeNotSupported(CogneeApiError):
     def __init__(
         self,
diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
index 1e3e1ce00..bef4493b4 100644
--- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py
+++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
@@ -8,6 +8,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.users.models import User
 from cognee.shared.utils import send_telemetry
+from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
 
 logger = get_logger(level=ERROR)
 
@@ -149,6 +150,9 @@ async def brute_force_search(
             ]
         )
 
+        if not any(results):
+            raise CollectionDistancesNotFoundError()
+
         node_distances = {collection: result for collection, result in zip(collections, results)}
 
         await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
diff --git a/cognee/tests/unit/modules/retrieval/utils/brute_force_triplet_search_test.py b/cognee/tests/unit/modules/retrieval/utils/brute_force_triplet_search_test.py
new file mode 100644
index 000000000..9af0af42a
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/utils/brute_force_triplet_search_test.py
@@ -0,0 +1,48 @@
+import pytest
+from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
+from cognee.modules.users.models import User
+from cognee.modules.retrieval.utils.brute_force_triplet_search import (
+    brute_force_search,
+    brute_force_triplet_search,
+)
+from unittest.mock import AsyncMock, patch
+
+
+@pytest.mark.asyncio
+@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
+async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
+    # The mocked engine returns no distances for the requested collection.
+    requester = User(id="test_user")
+    search_kwargs = {
+        "collections": ["nonexistent_collection"],
+        "memory_fragment": AsyncMock(),
+    }
+    engine = AsyncMock()
+    engine.get_distance_from_collection_elements.return_value = []
+    mock_get_vector_engine.return_value = engine
+
+    with pytest.raises(Exception) as raised:
+        await brute_force_search("test query", requester, 5, **search_kwargs)
+
+    # The test expects the domain error as the chained cause of what is raised.
+    assert isinstance(raised.value.__cause__, CollectionDistancesNotFoundError)
+
+
+@pytest.mark.asyncio
+@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
+async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_engine):
+    # The mocked engine returns no distances for the requested collection.
+    requester = User(id="test_user")
+    search_kwargs = {
+        "collections": ["nonexistent_collection"],
+        "memory_fragment": AsyncMock(),
+    }
+    engine = AsyncMock()
+    engine.get_distance_from_collection_elements.return_value = []
+    mock_get_vector_engine.return_value = engine
+
+    with pytest.raises(Exception) as raised:
+        await brute_force_triplet_search("test query", requester, 5, **search_kwargs)
+
+    # The test expects the domain error as the chained cause of what is raised.
+    assert isinstance(raised.value.__cause__, CollectionDistancesNotFoundError)