feat: enable batch search in brute_force_triplet_search

This commit is contained in:
lxobr 2026-01-12 14:40:17 +01:00
parent 5ac288afa3
commit 7833189001

View file

@ -1,4 +1,4 @@
from typing import List, Optional, Type
from typing import List, Optional, Type, Union
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
@ -72,8 +72,18 @@ async def _get_top_triplet_importances(
triplet_distance_penalty: float,
wide_search_limit: Optional[int],
top_k: int,
) -> List[Edge]:
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances."""
query_list_length: Optional[int] = None,
) -> Union[List[Edge], List[List[Edge]]]:
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances.
Args:
query_list_length: Number of queries in batch mode (None for single-query mode).
When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists.
Returns:
List[Edge]: For single-query mode (query_list_length is None).
List[List[Edge]]: For batch mode (query_list_length is set), one list per query.
"""
if memory_fragment is None:
if wide_search_limit is None:
relevant_node_ids = None
@ -89,17 +99,20 @@ async def _get_top_triplet_importances(
)
await memory_fragment.map_vector_distances_to_graph_nodes(
node_distances=vector_search.node_distances
node_distances=vector_search.node_distances, query_list_length=query_list_length
)
await memory_fragment.map_vector_distances_to_graph_edges(
edge_distances=vector_search.edge_distances
edge_distances=vector_search.edge_distances, query_list_length=query_list_length
)
return await memory_fragment.calculate_top_triplet_importances(k=top_k)
return await memory_fragment.calculate_top_triplet_importances(
k=top_k, query_list_length=query_list_length
)
async def brute_force_triplet_search(
query: str,
query: Optional[str] = None,
query_batch: Optional[List[str]] = None,
top_k: int = 5,
collections: Optional[List[str]] = None,
properties_to_project: Optional[List[str]] = None,
@ -108,30 +121,49 @@ async def brute_force_triplet_search(
node_name: Optional[List[str]] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
) -> Union[List[Edge], List[List[Edge]]]:
"""
Performs a brute force search to retrieve the top triplets from the graph.
Args:
query (str): The search query.
query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided.
query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided.
top_k (int): The number of top results to retrieve.
collections (Optional[List[str]]): List of collections to query.
properties_to_project (Optional[List[str]]): List of properties to project.
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type: Node type to filter by.
node_name (Optional[List[str]]): Node names to filter by.
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections.
Ignored in batch mode (always None to project full graph) and when node_name is provided.
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
Returns:
list: The top triplet results.
List[Edge]: The top triplet results for single query mode (flat list).
List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists).
Note:
In single-query mode, node_distances and edge_distances are stored as flat lists.
In batch mode, they are stored as list-of-lists (one list per query).
"""
if not query or not isinstance(query, str):
if query is not None and query_batch is not None:
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
if query is None and query_batch is None:
raise ValueError("Must provide either 'query' or 'query_batch'.")
if query is not None and (not query or not isinstance(query, str)):
raise ValueError("The query must be a non-empty string.")
if query_batch is not None:
if not isinstance(query_batch, list) or not query_batch:
raise ValueError("query_batch must be a non-empty list of strings.")
if not all(isinstance(q, str) and q for q in query_batch):
raise ValueError("All items in query_batch must be non-empty strings.")
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
wide_search_limit = wide_search_top_k if node_name is None else None
query_list_length = len(query_batch) if query_batch is not None else None
wide_search_limit = (
None if query_list_length else (wide_search_top_k if node_name is None else None)
)
if collections is None:
collections = [
@ -148,13 +180,16 @@ async def brute_force_triplet_search(
vector_search = NodeEdgeVectorSearch()
await vector_search.embed_and_retrieve_distances(
query=query, collections=collections, wide_search_limit=wide_search_limit
query=None if query_list_length else query,
query_batch=query_batch if query_list_length else None,
collections=collections,
wide_search_limit=wide_search_limit,
)
if not vector_search.has_results():
return []
return [[] for _ in range(query_list_length)] if query_list_length else []
return await _get_top_triplet_importances(
results = await _get_top_triplet_importances(
memory_fragment,
vector_search,
properties_to_project,
@ -163,13 +198,16 @@ async def brute_force_triplet_search(
triplet_distance_penalty,
wide_search_limit,
top_k,
query_list_length=query_list_length,
)
return results
except CollectionNotFoundError:
return []
return [[] for _ in range(query_list_length)] if query_list_length else []
except Exception as error:
logger.error(
"Error during brute force search for query: %s. Error: %s",
query,
query_batch if query_list_length else [query],
error,
)
raise error