feat: enable batch search in brute_force_triplet_search

This commit is contained in:
lxobr 2026-01-12 14:40:17 +01:00
parent 5ac288afa3
commit 7833189001

View file

@ -1,4 +1,4 @@
from typing import List, Optional, Type from typing import List, Optional, Type, Union
from cognee.shared.logging_utils import get_logger, ERROR from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
@ -72,8 +72,18 @@ async def _get_top_triplet_importances(
triplet_distance_penalty: float, triplet_distance_penalty: float,
wide_search_limit: Optional[int], wide_search_limit: Optional[int],
top_k: int, top_k: int,
) -> List[Edge]: query_list_length: Optional[int] = None,
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances.""" ) -> Union[List[Edge], List[List[Edge]]]:
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances.
Args:
query_list_length: Number of queries in batch mode (None for single-query mode).
When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists.
Returns:
List[Edge]: For single-query mode (query_list_length is None).
List[List[Edge]]: For batch mode (query_list_length is set), one list per query.
"""
if memory_fragment is None: if memory_fragment is None:
if wide_search_limit is None: if wide_search_limit is None:
relevant_node_ids = None relevant_node_ids = None
@ -89,17 +99,20 @@ async def _get_top_triplet_importances(
) )
await memory_fragment.map_vector_distances_to_graph_nodes( await memory_fragment.map_vector_distances_to_graph_nodes(
node_distances=vector_search.node_distances node_distances=vector_search.node_distances, query_list_length=query_list_length
) )
await memory_fragment.map_vector_distances_to_graph_edges( await memory_fragment.map_vector_distances_to_graph_edges(
edge_distances=vector_search.edge_distances edge_distances=vector_search.edge_distances, query_list_length=query_list_length
) )
return await memory_fragment.calculate_top_triplet_importances(k=top_k) return await memory_fragment.calculate_top_triplet_importances(
k=top_k, query_list_length=query_list_length
)
async def brute_force_triplet_search( async def brute_force_triplet_search(
query: str, query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
top_k: int = 5, top_k: int = 5,
collections: Optional[List[str]] = None, collections: Optional[List[str]] = None,
properties_to_project: Optional[List[str]] = None, properties_to_project: Optional[List[str]] = None,
@ -108,30 +121,49 @@ async def brute_force_triplet_search(
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
wide_search_top_k: Optional[int] = 100, wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5, triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]: ) -> Union[List[Edge], List[List[Edge]]]:
""" """
Performs a brute force search to retrieve the top triplets from the graph. Performs a brute force search to retrieve the top triplets from the graph.
Args: Args:
query (str): The search query. query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided.
query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided.
top_k (int): The number of top results to retrieve. top_k (int): The number of top results to retrieve.
collections (Optional[List[str]]): List of collections to query. collections (Optional[List[str]]): List of collections to query.
properties_to_project (Optional[List[str]]): List of properties to project. properties_to_project (Optional[List[str]]): List of properties to project.
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type: node type to filter node_type: node type to filter
node_name: node name to filter node_name: node name to filter
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections.
Ignored in batch mode (always None to project full graph).
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
Returns: Returns:
list: The top triplet results. List[Edge]: The top triplet results for single query mode (flat list).
List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists).
Note:
In single-query mode, node_distances and edge_distances are stored as flat lists.
In batch mode, they are stored as list-of-lists (one list per query).
""" """
if not query or not isinstance(query, str): if query is not None and query_batch is not None:
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
if query is None and query_batch is None:
raise ValueError("Must provide either 'query' or 'query_batch'.")
if query is not None and (not query or not isinstance(query, str)):
raise ValueError("The query must be a non-empty string.") raise ValueError("The query must be a non-empty string.")
if query_batch is not None:
if not isinstance(query_batch, list) or not query_batch:
raise ValueError("query_batch must be a non-empty list of strings.")
if not all(isinstance(q, str) and q for q in query_batch):
raise ValueError("All items in query_batch must be non-empty strings.")
if top_k <= 0: if top_k <= 0:
raise ValueError("top_k must be a positive integer.") raise ValueError("top_k must be a positive integer.")
wide_search_limit = wide_search_top_k if node_name is None else None query_list_length = len(query_batch) if query_batch is not None else None
wide_search_limit = (
None if query_list_length else (wide_search_top_k if node_name is None else None)
)
if collections is None: if collections is None:
collections = [ collections = [
@ -148,13 +180,16 @@ async def brute_force_triplet_search(
vector_search = NodeEdgeVectorSearch() vector_search = NodeEdgeVectorSearch()
await vector_search.embed_and_retrieve_distances( await vector_search.embed_and_retrieve_distances(
query=query, collections=collections, wide_search_limit=wide_search_limit query=None if query_list_length else query,
query_batch=query_batch if query_list_length else None,
collections=collections,
wide_search_limit=wide_search_limit,
) )
if not vector_search.has_results(): if not vector_search.has_results():
return [] return [[] for _ in range(query_list_length)] if query_list_length else []
return await _get_top_triplet_importances( results = await _get_top_triplet_importances(
memory_fragment, memory_fragment,
vector_search, vector_search,
properties_to_project, properties_to_project,
@ -163,13 +198,16 @@ async def brute_force_triplet_search(
triplet_distance_penalty, triplet_distance_penalty,
wide_search_limit, wide_search_limit,
top_k, top_k,
query_list_length=query_list_length,
) )
return results
except CollectionNotFoundError: except CollectionNotFoundError:
return [] return [[] for _ in range(query_list_length)] if query_list_length else []
except Exception as error: except Exception as error:
logger.error( logger.error(
"Error during brute force search for query: %s. Error: %s", "Error during brute force search for query: %s. Error: %s",
query, query_batch if query_list_length else [query],
error, error,
) )
raise error raise error