From c79af6c8cccb241200d4596a87a068c013854a74 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 9 Jan 2026 10:05:44 +0100 Subject: [PATCH] refactor: brute_force_triplet_search.py and node_edge_vector_search.py --- .../utils/brute_force_triplet_search.py | 170 ++++-------------- .../utils/node_edge_vector_search.py | 81 +++++++++ 2 files changed, 119 insertions(+), 132 deletions(-) create mode 100644 cognee/modules/retrieval/utils/node_edge_vector_search.py diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 50d16edb2..ef805a127 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,14 +1,11 @@ -import asyncio -import time -from typing import Any, List, Optional, Type +from typing import List, Optional, Type from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.retrieval.utils.node_edge_vector_search import NodeEdgeVectorSearch logger = get_logger(level=ERROR) @@ -65,122 +62,36 @@ async def get_memory_fragment( return memory_fragment -class TripletSearchContext: - """Pure state container for triplet search operations.""" - - def __init__( - self, - query: str, - top_k: int, - collections: List[str], - properties_to_project: Optional[List[str]], - node_type: Optional[Type], - node_name: Optional[List[str]], - wide_search_limit: Optional[int], - triplet_distance_penalty: float, - ): - self.query = query - self.top_k = top_k - self.collections = collections - self.properties_to_project = properties_to_project - self.node_type = node_type - self.node_name = node_name - self.wide_search_limit = wide_search_limit - self.triplet_distance_penalty = triplet_distance_penalty - - self.query_vector = None - self.node_distances = None - self.edge_distances = None - - def has_results(self) -> bool: - """Checks if any collections returned results.""" - return bool(self.edge_distances or any(self.node_distances.values())) - - def extract_relevant_node_ids(self) -> Optional[List[str]]: - """Extracts unique node IDs from search results to filter graph projection.""" - if self.wide_search_limit is None: - return None - - relevant_node_ids = { - str(getattr(scored_node, "id")) - for score_collection in self.node_distances.values() - if isinstance(score_collection, (list, tuple)) - for scored_node in score_collection - if getattr(scored_node, "id", None) - } - return list(relevant_node_ids) - - def set_distances_from_results(self, search_results: List[List[Any]]): - """Separates search results into node and edge distances.""" - self.node_distances = {} - for collection, result in zip(self.collections, search_results): - if collection == "EdgeType_relationship_name": - self.edge_distances = result - else: - self.node_distances[collection] = result - - -async def _search_single_collection( - vector_engine: Any, search_context: TripletSearchContext, collection_name: str -): - """Searches one collection and returns results or empty list if not found.""" - try: - return await vector_engine.search( - collection_name=collection_name, - query_vector=search_context.query_vector, - limit=search_context.wide_search_limit, +async def _get_top_triplet_importances( + memory_fragment: Optional[CogneeGraph], + vector_search: NodeEdgeVectorSearch, + properties_to_project: Optional[List[str]], + node_type: Optional[Type], + node_name: Optional[List[str]], + triplet_distance_penalty: float, + wide_search_limit: Optional[int], + top_k: int, +) -> List[Edge]: + """Creates memory fragment (if needed), maps distances, and calculates top triplet importances.""" + if memory_fragment is None: + relevant_node_ids = vector_search.extract_relevant_node_ids() if wide_search_limit else None + memory_fragment = await get_memory_fragment( + properties_to_project=properties_to_project, + node_type=node_type, + node_name=node_name, + relevant_ids_to_filter=relevant_node_ids, + triplet_distance_penalty=triplet_distance_penalty, ) - except CollectionNotFoundError: - return [] - -async def _embed_and_retrieve_distances(search_context: TripletSearchContext): - """Embeds query and retrieves vector distances from all collections.""" - vector_engine = get_vector_engine() - - query_embeddings = await vector_engine.embedding_engine.embed_text([search_context.query]) - search_context.query_vector = query_embeddings[0] - - start_time = time.time() - search_tasks = [ - _search_single_collection(vector_engine, search_context, collection) - for collection in search_context.collections - ] - search_results = await asyncio.gather(*search_tasks) - - elapsed_time = time.time() - start_time - collections_with_results = sum(1 for result in search_results if result) - logger.info( - f"Vector collection retrieval completed: Retrieved distances from " - f"{collections_with_results} collections in {elapsed_time:.2f}s" - ) - - search_context.set_distances_from_results(search_results) - - -async def _create_memory_fragment(search_context: TripletSearchContext) -> CogneeGraph: - """Creates memory fragment using search context properties.""" - relevant_node_ids = search_context.extract_relevant_node_ids() - return await get_memory_fragment( - properties_to_project=search_context.properties_to_project, - node_type=search_context.node_type, - node_name=search_context.node_name, - relevant_ids_to_filter=relevant_node_ids, - triplet_distance_penalty=search_context.triplet_distance_penalty, - ) - - -async def _map_distances_to_fragment( - search_context: TripletSearchContext, memory_fragment: CogneeGraph -): - """Maps vector distances from search context to memory fragment.""" await memory_fragment.map_vector_distances_to_graph_nodes( - node_distances=search_context.node_distances + node_distances=vector_search.node_distances ) await memory_fragment.map_vector_distances_to_graph_edges( - edge_distances=search_context.edge_distances + edge_distances=vector_search.edge_distances ) + return await memory_fragment.calculate_top_triplet_importances(k=top_k) + async def brute_force_triplet_search( query: str, @@ -229,28 +140,23 @@ async def brute_force_triplet_search( collections.append("EdgeType_relationship_name") try: - search_context = TripletSearchContext( - query=query, - top_k=top_k, - collections=collections, - properties_to_project=properties_to_project, - node_type=node_type, - node_name=node_name, - wide_search_limit=wide_search_limit, - triplet_distance_penalty=triplet_distance_penalty, - ) + vector_search = NodeEdgeVectorSearch() - await _embed_and_retrieve_distances(search_context) + await vector_search.embed_and_retrieve_distances(query, collections, wide_search_limit) - if not search_context.has_results(): + if not vector_search.has_results(): return [] - if memory_fragment is None: - memory_fragment = await _create_memory_fragment(search_context) - - await _map_distances_to_fragment(search_context, memory_fragment) - - return await memory_fragment.calculate_top_triplet_importances(k=search_context.top_k) + return await _get_top_triplet_importances( + memory_fragment, + vector_search, + properties_to_project, + node_type, + node_name, + triplet_distance_penalty, + wide_search_limit, + top_k, + ) except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s", diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py new file mode 100644 index 000000000..777751cf2 --- /dev/null +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -0,0 +1,81 @@ +import asyncio +import time +from typing import Any, List, Optional + +from cognee.shared.logging_utils import get_logger, ERROR +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError +from cognee.infrastructure.databases.vector import get_vector_engine + +logger = get_logger(level=ERROR) + + +class NodeEdgeVectorSearch: + """Manages vector search and distance retrieval for graph nodes and edges.""" + + def __init__(self, edge_collection: str = "EdgeType_relationship_name"): + self.edge_collection = edge_collection + self.query_vector: Optional[Any] = None + self.node_distances: dict[str, list[Any]] = {} + self.edge_distances: Optional[list[Any]] = None + + def has_results(self) -> bool: + """Checks if any collections returned results.""" + return bool(self.edge_distances) or any(self.node_distances.values()) + + def set_distances_from_results(self, collections: List[str], search_results: List[List[Any]]): + """Separates search results into node and edge distances.""" + self.node_distances = {} + for collection, result in zip(collections, search_results): + if collection == self.edge_collection: + self.edge_distances = result + else: + self.node_distances[collection] = result + + def extract_relevant_node_ids(self) -> List[str]: + """Extracts unique node IDs from search results.""" + relevant_node_ids = { + str(getattr(scored_node, "id")) + for score_collection in self.node_distances.values() + if isinstance(score_collection, (list, tuple)) + for scored_node in score_collection + if getattr(scored_node, "id", None) + } + return list(relevant_node_ids) + + async def embed_and_retrieve_distances( + self, query: str, collections: List[str], wide_search_limit: Optional[int] + ): + """Embeds query and retrieves vector distances from all collections.""" + vector_engine = get_vector_engine() + + query_embeddings = await vector_engine.embedding_engine.embed_text([query]) + self.query_vector = query_embeddings[0] + + start_time = time.time() + search_tasks = [ + self._search_single_collection(vector_engine, wide_search_limit, collection) + for collection in collections + ] + search_results = await asyncio.gather(*search_tasks) + + elapsed_time = time.time() - start_time + collections_with_results = sum(1 for result in search_results if result) + logger.info( + f"Vector collection retrieval completed: Retrieved distances from " + f"{collections_with_results} collections in {elapsed_time:.2f}s" + ) + + self.set_distances_from_results(collections, search_results) + + async def _search_single_collection( + self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str + ): + """Searches one collection and returns results or empty list if not found.""" + try: + return await vector_engine.search( + collection_name=collection_name, + query_vector=self.query_vector, + limit=wide_search_limit, + ) + except CollectionNotFoundError: + return []