chore: handle empty distance list in brute force search [cog-1424] (#654)

<!-- .github/pull_request_template.md -->

## Description
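- Handle an empty distance list in brute force search by raising `CollectionDistancesNotFoundError` when no collection returns any distances.
- Add unit tests covering the case where the requested collection does not exist.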

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin

---------

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
This commit is contained in:
alekszievr 2025-03-25 15:50:02 +01:00 committed by GitHub
parent 14237f64e2
commit 936fcf7cd7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 63 additions and 1 deletions

View file

@ -4,4 +4,4 @@ Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various data errors
"""
from .exceptions import SearchTypeNotSupported, CypherSearchError
from .exceptions import SearchTypeNotSupported, CypherSearchError, CollectionDistancesNotFoundError

View file

@ -2,6 +2,16 @@ from cognee.exceptions import CogneeApiError
from fastapi import status
class CollectionDistancesNotFoundError(CogneeApiError):
    """Error raised when a search produces no distances for any collection.

    The defaults describe the situation as a "not found" condition: the
    default message suggests the requested collection names may not exist,
    and the default status code is HTTP 404.
    """

    def __init__(
        self,
        message: str = "No distances found between the query and collections. It is possible that the given collection names don't exist.",
        name: str = "CollectionDistancesNotFoundError",
        status_code: int = status.HTTP_404_NOT_FOUND,
    ) -> None:
        # Arguments are forwarded positionally, in the order the base class
        # is called with elsewhere in this module.
        super().__init__(message, name, status_code)
class SearchTypeNotSupported(CogneeApiError):
def __init__(
self,

View file

@ -8,6 +8,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
logger = get_logger(level=ERROR)
@ -149,6 +150,9 @@ async def brute_force_search(
]
)
if all(not item for item in results):
raise CollectionDistancesNotFoundError()
node_distances = {collection: result for collection, result in zip(collections, results)}
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)

View file

@ -0,0 +1,48 @@
import pytest
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
from cognee.modules.users.models import User
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_search,
brute_force_triplet_search,
)
from unittest.mock import AsyncMock, patch
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
    """brute_force_search fails with CollectionDistancesNotFoundError as the cause.

    The patched vector engine returns an empty distance list, which is the
    situation expected for a collection name that does not exist.
    """
    # Stub vector engine: every distance lookup yields an empty list.
    vector_engine_stub = AsyncMock()
    vector_engine_stub.get_distance_from_collection_elements.return_value = []
    mock_get_vector_engine.return_value = vector_engine_stub

    # Inputs for the search call.
    search_user = User(id="test_user")
    search_query = "test query"
    limit = 5
    keyword_args = {
        "collections": ["nonexistent_collection"],
        "memory_fragment": AsyncMock(),
    }

    with pytest.raises(Exception) as raised:
        await brute_force_search(search_query, search_user, limit, **keyword_args)

    # The domain error is checked on the chained cause, not on the raised
    # exception itself.
    assert isinstance(raised.value.__cause__, CollectionDistancesNotFoundError)
@pytest.mark.asyncio
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_engine):
    """brute_force_triplet_search fails with CollectionDistancesNotFoundError as the cause.

    The patched vector engine returns an empty distance list, which is the
    situation expected for a collection name that does not exist.
    """
    # Stub vector engine: every distance lookup yields an empty list.
    vector_engine_stub = AsyncMock()
    vector_engine_stub.get_distance_from_collection_elements.return_value = []
    mock_get_vector_engine.return_value = vector_engine_stub

    # Inputs for the search call.
    search_user = User(id="test_user")
    search_query = "test query"
    limit = 5
    keyword_args = {
        "collections": ["nonexistent_collection"],
        "memory_fragment": AsyncMock(),
    }

    with pytest.raises(Exception) as raised:
        await brute_force_triplet_search(search_query, search_user, limit, **keyword_args)

    # The domain error is checked on the chained cause, not on the raised
    # exception itself.
    assert isinstance(raised.value.__cause__, CollectionDistancesNotFoundError)