chore: handle empty distance list in brute force search [cog-1424] (#654)
<!-- .github/pull_request_template.md --> ## Description - handle empty distance list in brute force search - unit tests ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin --------- Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
This commit is contained in:
parent
14237f64e2
commit
936fcf7cd7
4 changed files with 63 additions and 1 deletions
|
|
@ -4,4 +4,4 @@ Custom exceptions for the Cognee API.
|
|||
This module defines a set of exceptions for handling various data errors
|
||||
"""
|
||||
|
||||
from .exceptions import SearchTypeNotSupported, CypherSearchError
|
||||
from .exceptions import SearchTypeNotSupported, CypherSearchError, CollectionDistancesNotFoundError
|
||||
|
|
|
|||
|
|
@ -2,6 +2,16 @@ from cognee.exceptions import CogneeApiError
|
|||
from fastapi import status
|
||||
|
||||
|
||||
class CollectionDistancesNotFoundError(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "No distances found between the query and collections. It is possible that the given collection names don't exist.",
|
||||
name: str = "CollectionDistancesNotFoundError",
|
||||
status_code: int = status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class SearchTypeNotSupported(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
|||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
|
||||
|
||||
logger = get_logger(level=ERROR)
|
||||
|
||||
|
|
@ -149,6 +150,9 @@ async def brute_force_search(
|
|||
]
|
||||
)
|
||||
|
||||
if all(not item for item in results):
|
||||
raise CollectionDistancesNotFoundError()
|
||||
|
||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
import pytest
|
||||
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||
brute_force_search,
|
||||
brute_force_triplet_search,
|
||||
)
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
|
||||
async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
|
||||
user = User(id="test_user")
|
||||
query = "test query"
|
||||
collections = ["nonexistent_collection"]
|
||||
top_k = 5
|
||||
mock_memory_fragment = AsyncMock()
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await brute_force_search(
|
||||
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
||||
)
|
||||
|
||||
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
|
||||
async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_engine):
|
||||
user = User(id="test_user")
|
||||
query = "test query"
|
||||
collections = ["nonexistent_collection"]
|
||||
top_k = 5
|
||||
mock_memory_fragment = AsyncMock()
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await brute_force_triplet_search(
|
||||
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
||||
)
|
||||
|
||||
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)
|
||||
Loading…
Add table
Reference in a new issue