feat: enable batch search in brute_force_triplet_search
This commit is contained in:
parent
5ac288afa3
commit
7833189001
1 changed files with 56 additions and 18 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional, Type
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
|
|
@ -72,8 +72,18 @@ async def _get_top_triplet_importances(
|
|||
triplet_distance_penalty: float,
|
||||
wide_search_limit: Optional[int],
|
||||
top_k: int,
|
||||
) -> List[Edge]:
|
||||
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances."""
|
||||
query_list_length: Optional[int] = None,
|
||||
) -> Union[List[Edge], List[List[Edge]]]:
|
||||
"""Creates memory fragment (if needed), maps distances, and calculates top triplet importances.
|
||||
|
||||
Args:
|
||||
query_list_length: Number of queries in batch mode (None for single-query mode).
|
||||
When None, node_distances/edge_distances are flat lists; when set, they are list-of-lists.
|
||||
|
||||
Returns:
|
||||
List[Edge]: For single-query mode (query_list_length is None).
|
||||
List[List[Edge]]: For batch mode (query_list_length is set), one list per query.
|
||||
"""
|
||||
if memory_fragment is None:
|
||||
if wide_search_limit is None:
|
||||
relevant_node_ids = None
|
||||
|
|
@ -89,17 +99,20 @@ async def _get_top_triplet_importances(
|
|||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(
|
||||
node_distances=vector_search.node_distances
|
||||
node_distances=vector_search.node_distances, query_list_length=query_list_length
|
||||
)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||
edge_distances=vector_search.edge_distances
|
||||
edge_distances=vector_search.edge_distances, query_list_length=query_list_length
|
||||
)
|
||||
|
||||
return await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||
return await memory_fragment.calculate_top_triplet_importances(
|
||||
k=top_k, query_list_length=query_list_length
|
||||
)
|
||||
|
||||
|
||||
async def brute_force_triplet_search(
|
||||
query: str,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
top_k: int = 5,
|
||||
collections: Optional[List[str]] = None,
|
||||
properties_to_project: Optional[List[str]] = None,
|
||||
|
|
@ -108,30 +121,49 @@ async def brute_force_triplet_search(
|
|||
node_name: Optional[List[str]] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> List[Edge]:
|
||||
) -> Union[List[Edge], List[List[Edge]]]:
|
||||
"""
|
||||
Performs a brute force search to retrieve the top triplets from the graph.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
query (Optional[str]): The search query (single query mode). Exactly one of query or query_batch must be provided.
|
||||
query_batch (Optional[List[str]]): List of search queries (batch mode). Exactly one of query or query_batch must be provided.
|
||||
top_k (int): The number of top results to retrieve.
|
||||
collections (Optional[List[str]]): List of collections to query.
|
||||
properties_to_project (Optional[List[str]]): List of properties to project.
|
||||
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
|
||||
node_type: node type to filter
|
||||
node_name: node name to filter
|
||||
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
|
||||
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections.
|
||||
Ignored in batch mode (always None to project full graph).
|
||||
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
|
||||
|
||||
Returns:
|
||||
list: The top triplet results.
|
||||
List[Edge]: The top triplet results for single query mode (flat list).
|
||||
List[List[Edge]]: List of top triplet results (one per query) for batch mode (list-of-lists).
|
||||
|
||||
Note:
|
||||
In single-query mode, node_distances and edge_distances are stored as flat lists.
|
||||
In batch mode, they are stored as list-of-lists (one list per query).
|
||||
"""
|
||||
if not query or not isinstance(query, str):
|
||||
if query is not None and query_batch is not None:
|
||||
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
|
||||
if query is None and query_batch is None:
|
||||
raise ValueError("Must provide either 'query' or 'query_batch'.")
|
||||
if query is not None and (not query or not isinstance(query, str)):
|
||||
raise ValueError("The query must be a non-empty string.")
|
||||
if query_batch is not None:
|
||||
if not isinstance(query_batch, list) or not query_batch:
|
||||
raise ValueError("query_batch must be a non-empty list of strings.")
|
||||
if not all(isinstance(q, str) and q for q in query_batch):
|
||||
raise ValueError("All items in query_batch must be non-empty strings.")
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
wide_search_limit = wide_search_top_k if node_name is None else None
|
||||
query_list_length = len(query_batch) if query_batch is not None else None
|
||||
wide_search_limit = (
|
||||
None if query_list_length else (wide_search_top_k if node_name is None else None)
|
||||
)
|
||||
|
||||
if collections is None:
|
||||
collections = [
|
||||
|
|
@ -148,13 +180,16 @@ async def brute_force_triplet_search(
|
|||
vector_search = NodeEdgeVectorSearch()
|
||||
|
||||
await vector_search.embed_and_retrieve_distances(
|
||||
query=query, collections=collections, wide_search_limit=wide_search_limit
|
||||
query=None if query_list_length else query,
|
||||
query_batch=query_batch if query_list_length else None,
|
||||
collections=collections,
|
||||
wide_search_limit=wide_search_limit,
|
||||
)
|
||||
|
||||
if not vector_search.has_results():
|
||||
return []
|
||||
return [[] for _ in range(query_list_length)] if query_list_length else []
|
||||
|
||||
return await _get_top_triplet_importances(
|
||||
results = await _get_top_triplet_importances(
|
||||
memory_fragment,
|
||||
vector_search,
|
||||
properties_to_project,
|
||||
|
|
@ -163,13 +198,16 @@ async def brute_force_triplet_search(
|
|||
triplet_distance_penalty,
|
||||
wide_search_limit,
|
||||
top_k,
|
||||
query_list_length=query_list_length,
|
||||
)
|
||||
|
||||
return results
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
return [[] for _ in range(query_list_length)] if query_list_length else []
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"Error during brute force search for query: %s. Error: %s",
|
||||
query,
|
||||
query_batch if query_list_length else [query],
|
||||
error,
|
||||
)
|
||||
raise error
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue