From f9cb490ad96073e424489821eb5dd273ed966336 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Thu, 8 Jan 2026 17:20:28 +0100 Subject: [PATCH] refactor: brute_force_triplet_search.py --- .../utils/brute_force_triplet_search.py | 216 ++++++++++++------ 1 file changed, 141 insertions(+), 75 deletions(-) diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index a70fa661b..5f367ca7f 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import List, Optional, Type +from typing import Any, List, Optional, Type from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError @@ -9,13 +9,12 @@ from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge -from cognee.modules.users.models import User -from cognee.shared.utils import send_telemetry logger = get_logger(level=ERROR) def format_triplets(edges): + """Formats edges into human-readable triplet strings.""" triplets = [] for edge in edges: node1 = edge.node1 @@ -51,7 +50,6 @@ async def get_memory_fragment( try: graph_engine = await get_graph_engine() - await memory_fragment.project_graph_from_db( graph_engine, node_properties_to_project=properties_to_project, @@ -61,18 +59,142 @@ async def get_memory_fragment( relevant_ids_to_filter=relevant_ids_to_filter, triplet_distance_penalty=triplet_distance_penalty, ) - except EntityNotFoundError: - # This is expected behavior - continue with empty fragment pass except Exception as e: logger.error(f"Error during memory fragment creation: {str(e)}") - # Still return the fragment even if projection failed - pass return memory_fragment +class _BruteForceTripletSearchEngine: + """Internal search engine for brute force triplet search operations.""" + + def __init__( + self, + query: str, + top_k: int, + collections: List[str], + properties_to_project: Optional[List[str]], + memory_fragment: Optional[CogneeGraph], + node_type: Optional[Type], + node_name: Optional[List[str]], + wide_search_limit: Optional[int], + triplet_distance_penalty: float, + ): + self.query = query + self.top_k = top_k + self.collections = collections + self.properties_to_project = properties_to_project + self.memory_fragment = memory_fragment + self.node_type = node_type + self.node_name = node_name + self.wide_search_limit = wide_search_limit + self.triplet_distance_penalty = triplet_distance_penalty + self.vector_engine = self._load_vector_engine() + self.query_vector = None + self.node_distances = None + self.edge_distances = None + + async def search(self) -> List[Edge]: + """Orchestrates the brute force triplet search workflow.""" + await self._embed_query_text() + await self._retrieve_and_set_vector_distances() + + if not (self.edge_distances or any(self.node_distances.values())): + return [] + + await self._ensure_memory_fragment_is_loaded() + await self._map_distances_to_memory_fragment() + + return await self.memory_fragment.calculate_top_triplet_importances(k=self.top_k) + + def _load_vector_engine(self): + """Loads the vector engine instance.""" + try: + return get_vector_engine() + except Exception as e: + logger.error("Failed to initialize vector engine: %s", e) + raise RuntimeError("Initialization error") from e + + async def _embed_query_text(self): + """Converts query text into embedding vector.""" + query_embeddings = await self.vector_engine.embedding_engine.embed_text([self.query]) + self.query_vector = query_embeddings[0] + + async def _retrieve_and_set_vector_distances(self): + """Searches all collections in parallel and sets node/edge distances directly.""" + start_time = time.time() + search_results = await self._run_parallel_collection_searches() + elapsed_time = time.time() - start_time + + collections_with_results = sum(1 for result in search_results if result) + logger.info( + f"Vector collection retrieval completed: Retrieved distances from " + f"{collections_with_results} collections in {elapsed_time:.2f}s" + ) + + self.node_distances = {} + for collection, result in zip(self.collections, search_results): + if collection == "EdgeType_relationship_name": + self.edge_distances = result + else: + self.node_distances[collection] = result + + async def _run_parallel_collection_searches(self) -> List[List[Any]]: + """Executes vector searches across all collections concurrently.""" + search_tasks = [ + self._search_single_collection(collection_name) for collection_name in self.collections + ] + return await asyncio.gather(*search_tasks) + + async def _search_single_collection(self, collection_name: str): + """Searches one collection and returns results or empty list if not found.""" + try: + return await self.vector_engine.search( + collection_name=collection_name, + query_vector=self.query_vector, + limit=self.wide_search_limit, + ) + except CollectionNotFoundError: + return [] + + async def _ensure_memory_fragment_is_loaded(self): + """Loads memory fragment if not already provided.""" + if self.memory_fragment is None: + relevant_node_ids = self._extract_relevant_node_ids_for_filtering() + self.memory_fragment = await get_memory_fragment( + properties_to_project=self.properties_to_project, + node_type=self.node_type, + node_name=self.node_name, + relevant_ids_to_filter=relevant_node_ids, + triplet_distance_penalty=self.triplet_distance_penalty, + ) + + def _extract_relevant_node_ids_for_filtering(self) -> Optional[List[str]]: + """Extracts unique node IDs from search results to filter graph projection.""" + if self.wide_search_limit is None: + return None + + relevant_node_ids = { + str(getattr(scored_node, "id")) + for score_collection in self.node_distances.values() + if isinstance(score_collection, (list, tuple)) + for scored_node in score_collection + if getattr(scored_node, "id", None) + } + return list(relevant_node_ids) + + async def _map_distances_to_memory_fragment(self): + """Maps vector distances to nodes and edges in the memory fragment.""" + await self.memory_fragment.map_vector_distances_to_graph_nodes( + node_distances=self.node_distances + ) + await self.memory_fragment.map_vector_distances_to_graph_edges( + edge_distances=self.edge_distances + ) + + async def brute_force_triplet_search( query: str, top_k: int = 5, @@ -108,7 +230,6 @@ async def brute_force_triplet_search( # Setting wide search limit based on the parameters non_global_search = node_name is None - wide_search_limit = wide_search_top_k if non_global_search else None if collections is None: @@ -123,73 +244,18 @@ async def brute_force_triplet_search( collections.append("EdgeType_relationship_name") try: - vector_engine = get_vector_engine() - except Exception as e: - logger.error("Failed to initialize vector engine: %s", e) - raise RuntimeError("Initialization error") from e - - query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0] - - async def search_in_collection(collection_name: str): - try: - return await vector_engine.search( - collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit - ) - except CollectionNotFoundError: - return [] - - try: - start_time = time.time() - - results = await asyncio.gather( - *[search_in_collection(collection_name) for collection_name in collections] + engine = _BruteForceTripletSearchEngine( + query=query, + top_k=top_k, + collections=collections, + properties_to_project=properties_to_project, + memory_fragment=memory_fragment, + node_type=node_type, + node_name=node_name, + wide_search_limit=wide_search_limit, + triplet_distance_penalty=triplet_distance_penalty, ) - - if all(not item for item in results): - return [] - - # Final statistics - vector_collection_search_time = time.time() - start_time - logger.info( - f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s" - ) - - node_distances = {collection: result for collection, result in zip(collections, results)} - - edge_distances = node_distances.get("EdgeType_relationship_name", None) - - if wide_search_limit is not None: - relevant_ids_to_filter = list( - { - str(getattr(scored_node, "id")) - for collection_name, score_collection in node_distances.items() - if collection_name != "EdgeType_relationship_name" - and isinstance(score_collection, (list, tuple)) - for scored_node in score_collection - if getattr(scored_node, "id", None) - } - ) - else: - relevant_ids_to_filter = None - - if memory_fragment is None: - memory_fragment = await get_memory_fragment( - properties_to_project=properties_to_project, - node_type=node_type, - node_name=node_name, - relevant_ids_to_filter=relevant_ids_to_filter, - triplet_distance_penalty=triplet_distance_penalty, - ) - - await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) - await memory_fragment.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - - results = await memory_fragment.calculate_top_triplet_importances(k=top_k) - - return results - - except CollectionNotFoundError: - return [] + return await engine.search() except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s",