From 876120853f11bbaf1fd66d65d23d113d067928c1 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Fri, 9 Jan 2026 09:10:29 +0100 Subject: [PATCH] refactor: brute_force_triplet_search.py with context class --- .../utils/brute_force_triplet_search.py | 185 +++++++++--------- 1 file changed, 90 insertions(+), 95 deletions(-) diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 5f367ca7f..50d16edb2 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -23,12 +23,10 @@ def format_triplets(edges): node1_attributes = node1.attributes node2_attributes = node2.attributes - # Filter only non-None properties node1_info = {key: value for key, value in node1_attributes.items() if value is not None} node2_info = {key: value for key, value in node2_attributes.items() if value is not None} edge_info = {key: value for key, value in edge_attributes.items() if value is not None} - # Create the formatted triplet triplet = f"Node1: {node1_info}\nEdge: {edge_info}\nNode2: {node2_info}\n\n\n" triplets.append(triplet) @@ -67,8 +65,8 @@ async def get_memory_fragment( return memory_fragment -class _BruteForceTripletSearchEngine: - """Internal search engine for brute force triplet search operations.""" +class TripletSearchContext: + """Pure state container for triplet search operations.""" def __init__( self, @@ -76,7 +74,6 @@ class _BruteForceTripletSearchEngine: top_k: int, collections: List[str], properties_to_project: Optional[List[str]], - memory_fragment: Optional[CogneeGraph], node_type: Optional[Type], node_name: Optional[List[str]], wide_search_limit: Optional[int], @@ -86,92 +83,20 @@ class _BruteForceTripletSearchEngine: self.top_k = top_k self.collections = collections self.properties_to_project = properties_to_project - self.memory_fragment = memory_fragment self.node_type = node_type self.node_name = node_name self.wide_search_limit = wide_search_limit self.triplet_distance_penalty = triplet_distance_penalty - self.vector_engine = self._load_vector_engine() + self.query_vector = None self.node_distances = None self.edge_distances = None - async def search(self) -> List[Edge]: - """Orchestrates the brute force triplet search workflow.""" - await self._embed_query_text() - await self._retrieve_and_set_vector_distances() + def has_results(self) -> bool: + """Checks if any collections returned results.""" + return bool(self.edge_distances or any(self.node_distances.values())) - if not (self.edge_distances or any(self.node_distances.values())): - return [] - - await self._ensure_memory_fragment_is_loaded() - await self._map_distances_to_memory_fragment() - - return await self.memory_fragment.calculate_top_triplet_importances(k=self.top_k) - - def _load_vector_engine(self): - """Loads the vector engine instance.""" - try: - return get_vector_engine() - except Exception as e: - logger.error("Failed to initialize vector engine: %s", e) - raise RuntimeError("Initialization error") from e - - async def _embed_query_text(self): - """Converts query text into embedding vector.""" - query_embeddings = await self.vector_engine.embedding_engine.embed_text([self.query]) - self.query_vector = query_embeddings[0] - - async def _retrieve_and_set_vector_distances(self): - """Searches all collections in parallel and sets node/edge distances directly.""" - start_time = time.time() - search_results = await self._run_parallel_collection_searches() - elapsed_time = time.time() - start_time - - collections_with_results = sum(1 for result in search_results if result) - logger.info( - f"Vector collection retrieval completed: Retrieved distances from " - f"{collections_with_results} collections in {elapsed_time:.2f}s" - ) - - self.node_distances = {} - for collection, result in zip(self.collections, search_results): - if collection == "EdgeType_relationship_name": - self.edge_distances = result - else: - self.node_distances[collection] = result - - async def _run_parallel_collection_searches(self) -> List[List[Any]]: - """Executes vector searches across all collections concurrently.""" - search_tasks = [ - self._search_single_collection(collection_name) for collection_name in self.collections - ] - return await asyncio.gather(*search_tasks) - - async def _search_single_collection(self, collection_name: str): - """Searches one collection and returns results or empty list if not found.""" - try: - return await self.vector_engine.search( - collection_name=collection_name, - query_vector=self.query_vector, - limit=self.wide_search_limit, - ) - except CollectionNotFoundError: - return [] - - async def _ensure_memory_fragment_is_loaded(self): - """Loads memory fragment if not already provided.""" - if self.memory_fragment is None: - relevant_node_ids = self._extract_relevant_node_ids_for_filtering() - self.memory_fragment = await get_memory_fragment( - properties_to_project=self.properties_to_project, - node_type=self.node_type, - node_name=self.node_name, - relevant_ids_to_filter=relevant_node_ids, - triplet_distance_penalty=self.triplet_distance_penalty, - ) - - def _extract_relevant_node_ids_for_filtering(self) -> Optional[List[str]]: + def extract_relevant_node_ids(self) -> Optional[List[str]]: """Extracts unique node IDs from search results to filter graph projection.""" if self.wide_search_limit is None: return None @@ -185,14 +110,76 @@ class _BruteForceTripletSearchEngine: } return list(relevant_node_ids) - async def _map_distances_to_memory_fragment(self): - """Maps vector distances to nodes and edges in the memory fragment.""" - await self.memory_fragment.map_vector_distances_to_graph_nodes( - node_distances=self.node_distances - ) - await self.memory_fragment.map_vector_distances_to_graph_edges( - edge_distances=self.edge_distances + def set_distances_from_results(self, search_results: List[List[Any]]): + """Separates search results into node and edge distances.""" + self.node_distances = {} + for collection, result in zip(self.collections, search_results): + if collection == "EdgeType_relationship_name": + self.edge_distances = result + else: + self.node_distances[collection] = result + + +async def _search_single_collection( + vector_engine: Any, search_context: TripletSearchContext, collection_name: str +): + """Searches one collection and returns results or empty list if not found.""" + try: + return await vector_engine.search( + collection_name=collection_name, + query_vector=search_context.query_vector, + limit=search_context.wide_search_limit, ) + except CollectionNotFoundError: + return [] + + +async def _embed_and_retrieve_distances(search_context: TripletSearchContext): + """Embeds query and retrieves vector distances from all collections.""" + vector_engine = get_vector_engine() + + query_embeddings = await vector_engine.embedding_engine.embed_text([search_context.query]) + search_context.query_vector = query_embeddings[0] + + start_time = time.time() + search_tasks = [ + _search_single_collection(vector_engine, search_context, collection) + for collection in search_context.collections + ] + search_results = await asyncio.gather(*search_tasks) + + elapsed_time = time.time() - start_time + collections_with_results = sum(1 for result in search_results if result) + logger.info( + f"Vector collection retrieval completed: Retrieved distances from " + f"{collections_with_results} collections in {elapsed_time:.2f}s" + ) + + search_context.set_distances_from_results(search_results) + + +async def _create_memory_fragment(search_context: TripletSearchContext) -> CogneeGraph: + """Creates memory fragment using search context properties.""" + relevant_node_ids = search_context.extract_relevant_node_ids() + return await get_memory_fragment( + properties_to_project=search_context.properties_to_project, + node_type=search_context.node_type, + node_name=search_context.node_name, + relevant_ids_to_filter=relevant_node_ids, + triplet_distance_penalty=search_context.triplet_distance_penalty, + ) + + +async def _map_distances_to_fragment( + search_context: TripletSearchContext, memory_fragment: CogneeGraph +): + """Maps vector distances from search context to memory fragment.""" + await memory_fragment.map_vector_distances_to_graph_nodes( + node_distances=search_context.node_distances + ) + await memory_fragment.map_vector_distances_to_graph_edges( + edge_distances=search_context.edge_distances + ) async def brute_force_triplet_search( @@ -228,9 +215,7 @@ async def brute_force_triplet_search( if top_k <= 0: raise ValueError("top_k must be a positive integer.") - # Setting wide search limit based on the parameters - non_global_search = node_name is None - wide_search_limit = wide_search_top_k if non_global_search else None + wide_search_limit = wide_search_top_k if node_name is None else None if collections is None: collections = [ @@ -244,18 +229,28 @@ async def brute_force_triplet_search( collections.append("EdgeType_relationship_name") try: - engine = _BruteForceTripletSearchEngine( + search_context = TripletSearchContext( query=query, top_k=top_k, collections=collections, properties_to_project=properties_to_project, - memory_fragment=memory_fragment, node_type=node_type, node_name=node_name, wide_search_limit=wide_search_limit, triplet_distance_penalty=triplet_distance_penalty, ) - return await engine.search() + + await _embed_and_retrieve_distances(search_context) + + if not search_context.has_results(): + return [] + + if memory_fragment is None: + memory_fragment = await _create_memory_fragment(search_context) + + await _map_distances_to_fragment(search_context, memory_fragment) + + return await memory_fragment.calculate_top_triplet_importances(k=search_context.top_k) except Exception as error: logger.error( "Error during brute force search for query: %s. Error: %s",