feat: add batch search to node_edge_vector_search.py
This commit is contained in:
parent
58dd518690
commit
701a92cdec
2 changed files with 100 additions and 24 deletions
|
|
@ -147,7 +147,9 @@ async def brute_force_triplet_search(
|
||||||
try:
|
try:
|
||||||
vector_search = NodeEdgeVectorSearch()
|
vector_search = NodeEdgeVectorSearch()
|
||||||
|
|
||||||
await vector_search.embed_and_retrieve_distances(query, collections, wide_search_limit)
|
await vector_search.embed_and_retrieve_distances(
|
||||||
|
query=query, collections=collections, wide_search_limit=wide_search_limit
|
||||||
|
)
|
||||||
|
|
||||||
if not vector_search.has_results():
|
if not vector_search.has_results():
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,9 @@ class NodeEdgeVectorSearch:
|
||||||
self.edge_collection = edge_collection
|
self.edge_collection = edge_collection
|
||||||
self.vector_engine = vector_engine or self._init_vector_engine()
|
self.vector_engine = vector_engine or self._init_vector_engine()
|
||||||
self.query_vector: Optional[Any] = None
|
self.query_vector: Optional[Any] = None
|
||||||
self.node_distances: dict[str, list[Any]] = {}
|
self.node_distances: dict[str, list[list[Any]]] = {}
|
||||||
self.edge_distances: Optional[list[Any]] = None
|
self.edge_distances: list[list[Any]] = []
|
||||||
|
self.query_list_length: Optional[int] = None
|
||||||
|
|
||||||
def _init_vector_engine(self):
|
def _init_vector_engine(self):
|
||||||
try:
|
try:
|
||||||
|
|
@ -28,26 +29,56 @@ class NodeEdgeVectorSearch:
|
||||||
|
|
||||||
def has_results(self) -> bool:
|
def has_results(self) -> bool:
|
||||||
"""Checks if any collections returned results."""
|
"""Checks if any collections returned results."""
|
||||||
return bool(self.edge_distances) or any(self.node_distances.values())
|
if self.query_list_length is None:
|
||||||
|
if self.edge_distances and any(self.edge_distances):
|
||||||
|
return True
|
||||||
|
return any(
|
||||||
|
bool(collection_results) for collection_results in self.node_distances.values()
|
||||||
|
)
|
||||||
|
|
||||||
def set_distances_from_results(self, collections: List[str], search_results: List[List[Any]]):
|
if self.edge_distances and any(self.edge_distances):
|
||||||
"""Separates search results into node and edge distances."""
|
return True
|
||||||
|
return any(
|
||||||
|
any(results_per_query for results_per_query in collection_results)
|
||||||
|
for collection_results in self.node_distances.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_distances_from_results(
|
||||||
|
self,
|
||||||
|
collections: List[str],
|
||||||
|
search_results: List[List[Any]],
|
||||||
|
query_list_length: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""Separates search results into node and edge distances with stable shapes."""
|
||||||
self.node_distances = {}
|
self.node_distances = {}
|
||||||
|
self.edge_distances = (
|
||||||
|
[] if query_list_length is None else [[] for _ in range(query_list_length)]
|
||||||
|
)
|
||||||
for collection, result in zip(collections, search_results):
|
for collection, result in zip(collections, search_results):
|
||||||
if collection == self.edge_collection:
|
if not result:
|
||||||
self.edge_distances = result
|
empty_result = (
|
||||||
|
[] if query_list_length is None else [[] for _ in range(query_list_length)]
|
||||||
|
)
|
||||||
|
if collection == self.edge_collection:
|
||||||
|
self.edge_distances = empty_result
|
||||||
|
else:
|
||||||
|
self.node_distances[collection] = empty_result
|
||||||
else:
|
else:
|
||||||
self.node_distances[collection] = result
|
if collection == self.edge_collection:
|
||||||
|
self.edge_distances = result
|
||||||
|
else:
|
||||||
|
self.node_distances[collection] = result
|
||||||
|
|
||||||
def extract_relevant_node_ids(self) -> List[str]:
|
def extract_relevant_node_ids(self) -> List[str]:
|
||||||
"""Extracts unique node IDs from search results."""
|
"""Extracts unique node IDs from search results."""
|
||||||
relevant_node_ids = {
|
if self.query_list_length is not None:
|
||||||
str(getattr(scored_node, "id"))
|
return []
|
||||||
for score_collection in self.node_distances.values()
|
relevant_node_ids = set()
|
||||||
if isinstance(score_collection, (list, tuple))
|
for scored_results in self.node_distances.values():
|
||||||
for scored_node in score_collection
|
for scored_node in scored_results:
|
||||||
if getattr(scored_node, "id", None)
|
node_id = getattr(scored_node, "id", None)
|
||||||
}
|
if node_id:
|
||||||
|
relevant_node_ids.add(str(node_id))
|
||||||
return list(relevant_node_ids)
|
return list(relevant_node_ids)
|
||||||
|
|
||||||
async def _embed_query(self, query: str):
|
async def _embed_query(self, query: str):
|
||||||
|
|
@ -55,27 +86,70 @@ class NodeEdgeVectorSearch:
|
||||||
query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
|
query_embeddings = await self.vector_engine.embedding_engine.embed_text([query])
|
||||||
self.query_vector = query_embeddings[0]
|
self.query_vector = query_embeddings[0]
|
||||||
|
|
||||||
async def embed_and_retrieve_distances(
|
async def _run_batch_search(
|
||||||
self, query: str, collections: List[str], wide_search_limit: Optional[int]
|
self, collections: List[str], query_batch: List[str]
|
||||||
):
|
) -> List[List[Any]]:
|
||||||
"""Embeds query and retrieves vector distances from all collections."""
|
"""Runs batch search across all collections and returns list-of-lists per collection."""
|
||||||
await self._embed_query(query)
|
search_tasks = [
|
||||||
|
self._search_batch_collection(collection, query_batch) for collection in collections
|
||||||
|
]
|
||||||
|
return await asyncio.gather(*search_tasks)
|
||||||
|
|
||||||
start_time = time.time()
|
async def _search_batch_collection(
|
||||||
|
self, collection_name: str, query_batch: List[str]
|
||||||
|
) -> List[List[Any]]:
|
||||||
|
"""Searches one collection with batch queries and returns list-of-lists."""
|
||||||
|
try:
|
||||||
|
return await self.vector_engine.batch_search(
|
||||||
|
collection_name=collection_name, query_texts=query_batch, limit=None
|
||||||
|
)
|
||||||
|
except CollectionNotFoundError:
|
||||||
|
return [[]] * len(query_batch)
|
||||||
|
|
||||||
|
async def _run_single_search(
|
||||||
|
self, collections: List[str], query: str, wide_search_limit: Optional[int]
|
||||||
|
) -> List[List[Any]]:
|
||||||
|
"""Runs single query search and wraps results in list-of-lists for shape consistency."""
|
||||||
|
await self._embed_query(query)
|
||||||
search_tasks = [
|
search_tasks = [
|
||||||
self._search_single_collection(self.vector_engine, wide_search_limit, collection)
|
self._search_single_collection(self.vector_engine, wide_search_limit, collection)
|
||||||
for collection in collections
|
for collection in collections
|
||||||
]
|
]
|
||||||
search_results = await asyncio.gather(*search_tasks)
|
search_results = await asyncio.gather(*search_tasks)
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
async def embed_and_retrieve_distances(
|
||||||
|
self,
|
||||||
|
query: Optional[str] = None,
|
||||||
|
query_batch: Optional[List[str]] = None,
|
||||||
|
collections: List[str] = None,
|
||||||
|
wide_search_limit: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""Embeds query/queries and retrieves vector distances from all collections."""
|
||||||
|
if query is not None and query_batch is not None:
|
||||||
|
raise ValueError("Cannot provide both 'query' and 'query_batch'; use exactly one.")
|
||||||
|
if query is None and query_batch is None:
|
||||||
|
raise ValueError("Must provide either 'query' or 'query_batch'.")
|
||||||
|
if not collections:
|
||||||
|
raise ValueError("'collections' must be a non-empty list.")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
if query_batch is not None:
|
||||||
|
self.query_list_length = len(query_batch)
|
||||||
|
search_results = await self._run_batch_search(collections, query_batch)
|
||||||
|
else:
|
||||||
|
self.query_list_length = None
|
||||||
|
search_results = await self._run_single_search(collections, query, wide_search_limit)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
collections_with_results = sum(1 for result in search_results if result)
|
collections_with_results = sum(1 for result in search_results if any(result))
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Vector collection retrieval completed: Retrieved distances from "
|
f"Vector collection retrieval completed: Retrieved distances from "
|
||||||
f"{collections_with_results} collections in {elapsed_time:.2f}s"
|
f"{collections_with_results} collections in {elapsed_time:.2f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.set_distances_from_results(collections, search_results)
|
self.set_distances_from_results(collections, search_results, self.query_list_length)
|
||||||
|
|
||||||
async def _search_single_collection(
|
async def _search_single_collection(
|
||||||
self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str
|
self, vector_engine: Any, wide_search_limit: Optional[int], collection_name: str
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue