diff --git a/cognee/modules/retrieval/utils/node_edge_vector_search.py b/cognee/modules/retrieval/utils/node_edge_vector_search.py index 80116f6f2..558b9bc0c 100644 --- a/cognee/modules/retrieval/utils/node_edge_vector_search.py +++ b/cognee/modules/retrieval/utils/node_edge_vector_search.py @@ -27,6 +27,39 @@ class NodeEdgeVectorSearch: logger.error("Failed to initialize vector engine: %s", e) raise RuntimeError("Initialization error") from e + async def embed_and_retrieve_distances( + self, + query: Optional[str] = None, + query_batch: Optional[List[str]] = None, + collections: List[str] = None, + wide_search_limit: Optional[int] = None, + ): + """Embeds query/queries and retrieves vector distances from all collections.""" + if query is not None and query_batch is not None: + raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") + if query is None and query_batch is None: + raise ValueError("Must provide either 'query' or 'query_batch'.") + if not collections: + raise ValueError("'collections' must be a non-empty list.") + + start_time = time.time() + + if query_batch is not None: + self.query_list_length = len(query_batch) + search_results = await self._run_batch_search(collections, query_batch) + else: + self.query_list_length = None + search_results = await self._run_single_search(collections, query, wide_search_limit) + + elapsed_time = time.time() - start_time + collections_with_results = sum(1 for result in search_results if any(result)) + logger.info( + f"Vector collection retrieval completed: Retrieved distances from " + f"{collections_with_results} collections in {elapsed_time:.2f}s" + ) + + self.set_distances_from_results(collections, search_results, self.query_list_length) + def has_results(self) -> bool: """Checks if any collections returned results.""" if self.query_list_length is None: @@ -43,6 +76,18 @@ class NodeEdgeVectorSearch: for collection_results in self.node_distances.values() ) + def extract_relevant_node_ids(self) -> List[str]: + """Extracts unique node IDs from search results.""" + if self.query_list_length is not None: + return [] + relevant_node_ids = set() + for scored_results in self.node_distances.values(): + for scored_node in scored_results: + node_id = getattr(scored_node, "id", None) + if node_id: + relevant_node_ids.add(str(node_id)) + return list(relevant_node_ids) + def set_distances_from_results( self, collections: List[str], @@ -74,23 +119,6 @@ class NodeEdgeVectorSearch: else: self.node_distances[collection] = result - def extract_relevant_node_ids(self) -> List[str]: - """Extracts unique node IDs from search results.""" - if self.query_list_length is not None: - return [] - relevant_node_ids = set() - for scored_results in self.node_distances.values(): - for scored_node in scored_results: - node_id = getattr(scored_node, "id", None) - if node_id: - relevant_node_ids.add(str(node_id)) - return list(relevant_node_ids) - - async def _embed_query(self, query: str): - """Embeds the query and stores the resulting vector.""" - query_embeddings = await self.vector_engine.embedding_engine.embed_text([query]) - self.query_vector = query_embeddings[0] - async def _run_batch_search( self, collections: List[str], query_batch: List[str] ) -> List[List[Any]]: @@ -127,38 +155,10 @@ class NodeEdgeVectorSearch: search_results = await asyncio.gather(*search_tasks) return search_results - async def embed_and_retrieve_distances( - self, - query: Optional[str] = None, - query_batch: Optional[List[str]] = None, - collections: List[str] = None, - wide_search_limit: Optional[int] = None, - ): - """Embeds query/queries and retrieves vector distances from all collections.""" - if query is not None and query_batch is not None: - raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.") - if query is None and query_batch is None: - raise ValueError("Must provide either 'query' or 'query_batch'.") - if not collections: - raise ValueError("'collections' must be a non-empty list.") - - start_time = time.time() - - if query_batch is not None: - self.query_list_length = len(query_batch) - search_results = await self._run_batch_search(collections, query_batch) - else: - self.query_list_length = None - search_results = await self._run_single_search(collections, query, wide_search_limit) - - elapsed_time = time.time() - start_time - collections_with_results = sum(1 for result in search_results if any(result)) - logger.info( - f"Vector collection retrieval completed: Retrieved distances from " - f"{collections_with_results} collections in {elapsed_time:.2f}s" - ) - - self.set_distances_from_results(collections, search_results, self.query_list_length) + async def _embed_query(self, query: str): + """Embeds the query and stores the resulting vector.""" + query_embeddings = await self.vector_engine.embedding_engine.embed_text([query]) + self.query_vector = query_embeddings[0] async def _search_single_collection( self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str