chore: handle empty distance list in brute force search [cog-1424] (#654)
<!-- .github/pull_request_template.md --> ## Description - handle empty distance list in brute force search - unit tests ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin --------- Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
This commit is contained in:
parent
14237f64e2
commit
936fcf7cd7
4 changed files with 63 additions and 1 deletions
|
|
@ -4,4 +4,4 @@ Custom exceptions for the Cognee API.
|
||||||
This module defines a set of exceptions for handling various data errors
|
This module defines a set of exceptions for handling various data errors
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .exceptions import SearchTypeNotSupported, CypherSearchError
|
from .exceptions import SearchTypeNotSupported, CypherSearchError, CollectionDistancesNotFoundError
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,16 @@ from cognee.exceptions import CogneeApiError
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionDistancesNotFoundError(CogneeApiError):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "No distances found between the query and collections. It is possible that the given collection names don't exist.",
|
||||||
|
name: str = "CollectionDistancesNotFoundError",
|
||||||
|
status_code: int = status.HTTP_404_NOT_FOUND,
|
||||||
|
):
|
||||||
|
super().__init__(message, name, status_code)
|
||||||
|
|
||||||
|
|
||||||
class SearchTypeNotSupported(CogneeApiError):
|
class SearchTypeNotSupported(CogneeApiError):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
|
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
|
||||||
|
|
||||||
logger = get_logger(level=ERROR)
|
logger = get_logger(level=ERROR)
|
||||||
|
|
||||||
|
|
@ -149,6 +150,9 @@ async def brute_force_search(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if all(not item for item in results):
|
||||||
|
raise CollectionDistancesNotFoundError()
|
||||||
|
|
||||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||||
|
|
||||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
import pytest
|
||||||
|
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||||
|
brute_force_search,
|
||||||
|
brute_force_triplet_search,
|
||||||
|
)
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
|
||||||
|
async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
|
||||||
|
user = User(id="test_user")
|
||||||
|
query = "test query"
|
||||||
|
collections = ["nonexistent_collection"]
|
||||||
|
top_k = 5
|
||||||
|
mock_memory_fragment = AsyncMock()
|
||||||
|
mock_vector_engine = AsyncMock()
|
||||||
|
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
||||||
|
mock_get_vector_engine.return_value = mock_vector_engine
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await brute_force_search(
|
||||||
|
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
|
||||||
|
async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_engine):
|
||||||
|
user = User(id="test_user")
|
||||||
|
query = "test query"
|
||||||
|
collections = ["nonexistent_collection"]
|
||||||
|
top_k = 5
|
||||||
|
mock_memory_fragment = AsyncMock()
|
||||||
|
mock_vector_engine = AsyncMock()
|
||||||
|
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
||||||
|
mock_get_vector_engine.return_value = mock_vector_engine
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await brute_force_triplet_search(
|
||||||
|
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(exc_info.value.__cause__, CollectionDistancesNotFoundError)
|
||||||
Loading…
Add table
Reference in a new issue