diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 3c3603f01..a39ef50e1 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -147,7 +147,9 @@ async def brute_force_triplet_search( try: vector_search = NodeEdgeVectorSearch() - await vector_search.embed_and_retrieve_distances(query, collections, wide_search_limit) + await vector_search.embed_and_retrieve_distances( + query=query, collections=collections, wide_search_limit=wide_search_limit + ) if not vector_search.has_results(): return [] diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py index 08f76218c..e8dd0dc48 100644 --- a/cognee/modules/retrieval/utils/node_edge_vector_search.py +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -16,8 +16,9 @@ class NodeEdgeVectorSearch: self.edge_collection = edge_collection self.vector_engine = vector_engine or self._init_vector_engine() self.query_vector: Optional[Any] = None - self.node_distances: dict[str, list[Any]] = {} - self.edge_distances: Optional[list[Any]] = None + self.node_distances: dict[str, list[list[Any]]] = {} + self.edge_distances: list[list[Any]] = [] + self.query_list_length: Optional[int] = None def _init_vector_engine(self): try: @@ -28,26 +29,56 @@ class NodeEdgeVectorSearch: def has_results(self) -> bool: """Checks if any collections returned results.""" - return bool(self.edge_distances) or any(self.node_distances.values()) + if self.query_list_length is None: + if self.edge_distances and any(self.edge_distances): + return True + return any( + bool(collection_results) for collection_results in self.node_distances.values() + ) - def set_distances_from_results(self, collections: List[str], search_results: List[List[Any]]): - """Separates search results into node and edge distances.""" + if self.edge_distances and any(self.edge_distances): + return True + return any( + any(results_per_query for results_per_query in collection_results) + for collection_results in self.node_distances.values() + ) + + def set_distances_from_results( + self, + collections: List[str], + search_results: List[List[Any]], + query_list_length: Optional[int] = None, + ): + """Separates search results into node and edge distances with stable shapes.""" self.node_distances = {} + self.edge_distances = ( + [] if query_list_length is None else [[] for _ in range(query_list_length)] + ) for collection, result in zip(collections, search_results): - if collection == self.edge_collection: - self.edge_distances = result + if not result: + empty_result = ( + [] if query_list_length is None else [[] for _ in range(query_list_length)] + ) + if collection == self.edge_collection: + self.edge_distances = empty_result + else: + self.node_distances[collection] = empty_result else: - self.node_distances[collection] = result + if collection == self.edge_collection: + self.edge_distances = result + else: + self.node_distances[collection] = result def extract_relevant_node_ids(self) -> List[str]: """Extracts unique node IDs from search results.""" - relevant_node_ids = { - str(getattr(scored_node, "id")) - for score_collection in self.node_distances.values() - if isinstance(score_collection, (list, tuple)) - for scored_node in score_collection - if getattr(scored_node, "id", None) - } + if self.query_list_length is not None: + return [] + relevant_node_ids = set() + for scored_results in self.node_distances.values(): + for scored_node in scored_results: + node_id = getattr(scored_node, "id", None) + if node_id: + relevant_node_ids.add(str(node_id)) return list(relevant_node_ids) async def _embed_query(self, query: str): @@ -55,27 +86,70 @@ class NodeEdgeVectorSearch: query_embeddings = await self.vector_engine.embedding_engine.embed_text([query]) self.query_vector = query_embeddings[0] - async def embed_and_retrieve_distances( - self, query: str, collections: List[str], wide_search_limit: Optional[int] - ): - """Embeds query and retrieves vector distances from all collections.""" - await self._embed_query(query) + async def _run_batch_search( + self, collections: List[str], query_batch: List[str] + ) -> List[List[Any]]: + """Runs batch search across all collections and returns list-of-lists per collection.""" + search_tasks = [ + self._search_batch_collection(collection, query_batch) for collection in collections + ] + return await asyncio.gather(*search_tasks) - start_time = time.time() + async def _search_batch_collection( + self, collection_name: str, query_batch: List[str] + ) -> List[List[Any]]: + """Searches one collection with batch queries and returns list-of-lists.""" + try: + return await self.vector_engine.batch_search( + collection_name=collection_name, query_texts=query_batch, limit=None + ) + except CollectionNotFoundError: + return [[]] * len(query_batch) + + async def _run_single_search( + self, collections: List[str], query: str, wide_search_limit: Optional[int] + ) -> List[List[Any]]: + """Runs single query search and wraps results in list-of-lists for shape consistency.""" + await self._embed_query(query) search_tasks = [ self._search_single_collection(self.vector_engine, wide_search_limit, collection) for collection in collections ] search_results = await asyncio.gather(*search_tasks) + return search_results + + async def embed_and_retrieve_distances( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + collections: List[str] = None, + wide_search_limit: Optional[int] = None, + ): + """Embeds query/queries and retrieves vector distances from all collections.""" + if query is not None and query_batch is not None: + raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") + if query is None and query_batch is None: + raise ValueError("Must provide either 'query' or 'query_batch'.") + if not collections: + raise ValueError("'collections' must be a non-empty list.") + + start_time = time.time() + + if query_batch is not None: + self.query_list_length = len(query_batch) + search_results = await self._run_batch_search(collections, query_batch) + else: + self.query_list_length = None + search_results = await self._run_single_search(collections, query, wide_search_limit) elapsed_time = time.time() - start_time - collections_with_results = sum(1 for result in search_results if result) + collections_with_results = sum(1 for result in search_results if any(result)) logger.info( f"Vector collection retrieval completed: Retrieved distances from " f"{collections_with_results} collections in {elapsed_time:.2f}s" ) - self.set_distances_from_results(collections, search_results) + self.set_distances_from_results(collections, search_results, self.query_list_length) async def _search_single_collection( self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str