diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index a39ef50e1..ce84c1423 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Type +from typing import List, Optional, Type, Union from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError @@ -72,8 +72,18 @@ async def _get_top_triplet_importances( triplet_distance_penalty: float, wide_search_limit: Optional[int], top_k: int, -) -> List[Edge]: - """Creates memory fragment (if needed), maps distances, and calculates top triplet importances.""" + query_list_length: Optional[int] = None, +) -> Union[List[Edge], List[List[Edge]]]: + """Creates memory fragment (if needed), maps distances, and calculates top triplet importances. + + Args: + query_list_length: Number of queries in batch mode (None for single-query mode). + When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists. + + Returns: + List[Edge]: For single-query mode (query_list_length is None). + List[List[Edge]]: For batch mode (query_list_length is set), one list per query. + """ if memory_fragment is None: if wide_search_limit is None: relevant_node_ids = None @@ -89,17 +99,20 @@ async def _get_top_triplet_importances( ) await memory_fragment.map_vector_distances_to_graph_nodes( - node_distances=vector_search.node_distances + node_distances=vector_search.node_distances, query_list_length=query_list_length ) await memory_fragment.map_vector_distances_to_graph_edges( - edge_distances=vector_search.edge_distances + edge_distances=vector_search.edge_distances, query_list_length=query_list_length ) - return await memory_fragment.calculate_top_triplet_importances(k=top_k) + return await memory_fragment.calculate_top_triplet_importances( + k=top_k, query_list_length=query_list_length + ) async def brute_force_triplet_search( - query: str, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, top_k: int = 5, collections: Optional[List[str]] = None, properties_to_project: Optional[List[str]] = None, @@ -108,30 +121,49 @@ async def brute_force_triplet_search( node_name: Optional[List[str]] = None, wide_search_top_k: Optional[int] = 100, triplet_distance_penalty: Optional[float] = 3.5, -) -> List[Edge]: +) -> Union[List[Edge], List[List[Edge]]]: """ Performs a brute force search to retrieve the top triplets from the graph. Args: - query (str): The search query. + query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided. + query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided. top_k (int): The number of top results to retrieve. collections (Optional[List[str]]): List of collections to query. properties_to_project (Optional[List[str]]): List of properties to project. memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. node_type: node type to filter node_name: node name to filter - wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections + wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections. + Ignored in batch mode (always None to project full graph). triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection Returns: - list: The top triplet results. + List[Edge]: The top triplet results for single query mode (flat list). + List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists). + + Note: + In single-query mode, node_distances and edge_distances are stored as flat lists. + In batch mode, they are stored as list-of-lists (one list per query). """ - if not query or not isinstance(query, str): + if query is not None and query_batch is not None: + raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") + if query is None and query_batch is None: + raise ValueError("Must provide either 'query' or 'query_batch'.") + if query is not None and (not query or not isinstance(query, str)): raise ValueError("The query must be a non-empty string.") + if query_batch is not None: + if not isinstance(query_batch, list) or not query_batch: + raise ValueError("query_batch must be a non-empty list of strings.") + if not all(isinstance(q, str) and q for q in query_batch): + raise ValueError("All items in query_batch must be non-empty strings.") if top_k <= 0: raise ValueError("top_k must be a positive integer.") - wide_search_limit = wide_search_top_k if node_name is None else None + query_list_length = len(query_batch) if query_batch is not None else None + wide_search_limit = ( + None if query_list_length else (wide_search_top_k if node_name is None else None) + ) if collections is None: collections = [ @@ -148,13 +180,16 @@ async def brute_force_triplet_search( vector_search = NodeEdgeVectorSearch() await vector_search.embed_and_retrieve_distances( - query=query, collections=collections, wide_search_limit=wide_search_limit + query=None if query_list_length else query, + query_batch=query_batch if query_list_length else None, + collections=collections, + wide_search_limit=wide_search_limit, ) if not vector_search.has_results(): - return [] + return [[] for _ in range(query_list_length)] if query_list_length else [] - return await _get_top_triplet_importances( + results = await _get_top_triplet_importances( memory_fragment, vector_search, properties_to_project, @@ -163,13 +198,16 @@ async def brute_force_triplet_search( triplet_distance_penalty, wide_search_limit, top_k, + query_list_length=query_list_length, ) + + return results except CollectionNotFoundError: - return [] + return [[] for _ in range(query_list_length)] if query_list_length else [] except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s", - query, + query_batch if query_list_length else [query], error, ) raise error