refactor: improve methods order
This commit is contained in:
parent
872795f0cc
commit
c609b73cda
1 changed files with 49 additions and 49 deletions
|
|
@ -27,6 +27,39 @@ class NodeEdgeVectorSearch:
|
|||
logger.error("Failed to initialize vector engine: %s", e)
|
||||
raise RuntimeError("Initialization error") from e
|
||||
|
||||
async def embed_and_retrieve_distances(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
collections: List[str] = None,
|
||||
wide_search_limit: Optional[int] = None,
|
||||
):
|
||||
"""Embeds query/queries and retrieves vector distances from all collections."""
|
||||
if query is not None and query_batch is not None:
|
||||
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
|
||||
if query is None and query_batch is None:
|
||||
raise ValueError("Must provide either 'query' or 'query_batch'.")
|
||||
if not collections:
|
||||
raise ValueError("'collections' must be a non-empty list.")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if query_batch is not None:
|
||||
self.query_list_length = len(query_batch)
|
||||
search_results = await self._run_batch_search(collections, query_batch)
|
||||
else:
|
||||
self.query_list_length = None
|
||||
search_results = await self._run_single_search(collections, query, wide_search_limit)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
collections_with_results = sum(1 for result in search_results if any(result))
|
||||
logger.info(
|
||||
f"Vector collection retrieval completed: Retrieved distances from "
|
||||
f"{collections_with_results} collections in {elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
self.set_distances_from_results(collections, search_results, self.query_list_length)
|
||||
|
||||
def has_results(self) -> bool:
|
||||
"""Checks if any collections returned results."""
|
||||
if self.query_list_length is None:
|
||||
|
|
@ -43,6 +76,18 @@ class NodeEdgeVectorSearch:
|
|||
for collection_results in self.node_distances.values()
|
||||
)
|
||||
|
||||
def extract_relevant_node_ids(self) -> List[str]:
|
||||
"""Extracts unique node IDs from search results."""
|
||||
if self.query_list_length is not None:
|
||||
return []
|
||||
relevant_node_ids = set()
|
||||
for scored_results in self.node_distances.values():
|
||||
for scored_node in scored_results:
|
||||
node_id = getattr(scored_node, "id", None)
|
||||
if node_id:
|
||||
relevant_node_ids.add(str(node_id))
|
||||
return list(relevant_node_ids)
|
||||
|
||||
def set_distances_from_results(
|
||||
self,
|
||||
collections: List[str],
|
||||
|
|
@ -74,23 +119,6 @@ class NodeEdgeVectorSearch:
|
|||
else:
|
||||
self.node_distances[collection] = result
|
||||
|
||||
def extract_relevant_node_ids(self) -> List[str]:
|
||||
"""Extracts unique node IDs from search results."""
|
||||
if self.query_list_length is not None:
|
||||
return []
|
||||
relevant_node_ids = set()
|
||||
for scored_results in self.node_distances.values():
|
||||
for scored_node in scored_results:
|
||||
node_id = getattr(scored_node, "id", None)
|
||||
if node_id:
|
||||
relevant_node_ids.add(str(node_id))
|
||||
return list(relevant_node_ids)
|
||||
|
||||
async def _embed_query(self, query: str):
|
||||
"""Embeds the query and stores the resulting vector."""
|
||||
query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
|
||||
self.query_vector = query_embeddings[0]
|
||||
|
||||
async def _run_batch_search(
|
||||
self, collections: List[str], query_batch: List[str]
|
||||
) -> List[List[Any]]:
|
||||
|
|
@ -127,38 +155,10 @@ class NodeEdgeVectorSearch:
|
|||
search_results = await asyncio.gather(*search_tasks)
|
||||
return search_results
|
||||
|
||||
async def embed_and_retrieve_distances(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
query_batch: Optional[List[str]] = None,
|
||||
collections: List[str] = None,
|
||||
wide_search_limit: Optional[int] = None,
|
||||
):
|
||||
"""Embeds query/queries and retrieves vector distances from all collections."""
|
||||
if query is not None and query_batch is not None:
|
||||
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
|
||||
if query is None and query_batch is None:
|
||||
raise ValueError("Must provide either 'query' or 'query_batch'.")
|
||||
if not collections:
|
||||
raise ValueError("'collections' must be a non-empty list.")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if query_batch is not None:
|
||||
self.query_list_length = len(query_batch)
|
||||
search_results = await self._run_batch_search(collections, query_batch)
|
||||
else:
|
||||
self.query_list_length = None
|
||||
search_results = await self._run_single_search(collections, query, wide_search_limit)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
collections_with_results = sum(1 for result in search_results if any(result))
|
||||
logger.info(
|
||||
f"Vector collection retrieval completed: Retrieved distances from "
|
||||
f"{collections_with_results} collections in {elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
self.set_distances_from_results(collections, search_results, self.query_list_length)
|
||||
async def _embed_query(self, query: str):
|
||||
"""Embeds the query and stores the resulting vector."""
|
||||
query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
|
||||
self.query_vector = query_embeddings[0]
|
||||
|
||||
async def _search_single_collection(
|
||||
self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue