perf: add optional query_embedding parameter to avoid redundant embedding calls

This commit is contained in:
yangdx 2025-08-29 18:15:45 +08:00
parent a923d378dd
commit 03d0fa3014
8 changed files with 94 additions and 43 deletions

View file

@@ -211,8 +211,17 @@ class BaseVectorStorage(StorageNameSpace, ABC):
meta_fields: set[str] = field(default_factory=set) meta_fields: set[str] = field(default_factory=set)
@abstractmethod @abstractmethod
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
"""Query the vector storage and retrieve top_k results.""" self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results.
Args:
query: The query string to search for
top_k: Number of top results to return
query_embedding: Optional pre-computed embedding for the query.
If provided, skips embedding computation for better performance.
"""
@abstractmethod @abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:

View file

@@ -179,15 +179,21 @@ class FaissVectorDBStorage(BaseVectorStorage):
) )
return [m["__id__"] for m in list_data] return [m["__id__"] for m in list_data]
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
""" """
Search by a textual query; returns top_k results with their metadata + similarity distance. Search by a textual query; returns top_k results with their metadata + similarity distance.
""" """
embedding = await self.embedding_func( if query_embedding is not None:
[query], _priority=5 embedding = np.array([query_embedding], dtype=np.float32)
) # higher priority for query else:
# embedding is shape (1, dim) embedding = await self.embedding_func(
embedding = np.array(embedding, dtype=np.float32) [query], _priority=5
) # higher priority for query
# embedding is shape (1, dim)
embedding = np.array(embedding, dtype=np.float32)
faiss.normalize_L2(embedding) # we do in-place normalization faiss.normalize_L2(embedding) # we do in-place normalization
# Perform the similarity search # Perform the similarity search

View file

@@ -1046,13 +1046,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
) )
return results return results
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
# Ensure collection is loaded before querying # Ensure collection is loaded before querying
self._ensure_collection_loaded() self._ensure_collection_loaded()
embedding = await self.embedding_func( # Use provided embedding or compute it
[query], _priority=5 if query_embedding is not None:
) # higher priority for query embedding = [query_embedding] # Milvus expects a list of embeddings
else:
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# Include all meta_fields (created_at is now always included) # Include all meta_fields (created_at is now always included)
output_fields = list(self.meta_fields) output_fields = list(self.meta_fields)

View file

@@ -1809,15 +1809,19 @@ class MongoVectorDBStorage(BaseVectorStorage):
return list_data return list_data
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search.""" """Queries the vector database using Atlas Vector Search."""
# Generate the embedding if query_embedding is not None:
embedding = await self.embedding_func( query_vector = query_embedding
[query], _priority=5 else:
) # higher priority for query # Generate the embedding
embedding = await self.embedding_func(
# Convert numpy array to a list to ensure compatibility with MongoDB [query], _priority=5
query_vector = embedding[0].tolist() ) # higher priority for query
# Convert numpy array to a list to ensure compatibility with MongoDB
query_vector = embedding[0].tolist()
# Define the aggregation pipeline with the converted query vector # Define the aggregation pipeline with the converted query vector
pipeline = [ pipeline = [

View file

@@ -136,12 +136,18 @@ class NanoVectorDBStorage(BaseVectorStorage):
f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
) )
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
# Execute embedding outside of lock to avoid improve cocurrent self, query: str, top_k: int, query_embedding: list[float] = None
embedding = await self.embedding_func( ) -> list[dict[str, Any]]:
[query], _priority=5 # Use provided embedding or compute it
) # higher priority for query if query_embedding is not None:
embedding = embedding[0] embedding = query_embedding
else:
# Execute embedding outside of lock to avoid improve cocurrent
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embedding[0]
client = await self._get_client() client = await self._get_client()
results = client.query( results = client.query(

View file

@@ -2004,11 +2004,17 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data) await self.db.execute(upsert_sql, data)
#################### query method ############### #################### query method ###############
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
embeddings = await self.embedding_func( self, query: str, top_k: int, query_embedding: list[float] = None
[query], _priority=5 ) -> list[dict[str, Any]]:
) # higher priority for query if query_embedding is not None:
embedding = embeddings[0] embedding = query_embedding
else:
embeddings = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding)) embedding_string = ",".join(map(str, embedding))
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)

View file

@@ -199,13 +199,20 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
return results return results
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
embedding = await self.embedding_func( self, query: str, top_k: int, query_embedding: list[float] = None
[query], _priority=5 ) -> list[dict[str, Any]]:
) # higher priority for query if query_embedding is not None:
embedding = query_embedding
else:
embedding_result = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embedding_result[0]
results = self._client.search( results = self._client.search(
collection_name=self.final_namespace, collection_name=self.final_namespace,
query_vector=embedding[0], query_vector=embedding,
limit=top_k, limit=top_k,
with_payload=True, with_payload=True,
score_threshold=self.cosine_better_than_threshold, score_threshold=self.cosine_better_than_threshold,

View file

@@ -2234,6 +2234,7 @@ async def _get_vector_context(
query: str, query: str,
chunks_vdb: BaseVectorStorage, chunks_vdb: BaseVectorStorage,
query_param: QueryParam, query_param: QueryParam,
query_embedding: list[float] = None,
) -> list[dict]: ) -> list[dict]:
""" """
Retrieve text chunks from the vector database without reranking or truncation. Retrieve text chunks from the vector database without reranking or truncation.
@@ -2245,6 +2246,7 @@ async def _get_vector_context(
query: The query string to search for query: The query string to search for
chunks_vdb: Vector database containing document chunks chunks_vdb: Vector database containing document chunks
query_param: Query parameters including chunk_top_k and ids query_param: Query parameters including chunk_top_k and ids
query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls
Returns: Returns:
List of text chunks with metadata List of text chunks with metadata
@@ -2253,7 +2255,9 @@ async def _get_vector_context(
# Use chunk_top_k if specified, otherwise fall back to top_k # Use chunk_top_k if specified, otherwise fall back to top_k
search_top_k = query_param.chunk_top_k or query_param.top_k search_top_k = query_param.chunk_top_k or query_param.top_k
results = await chunks_vdb.query(query, top_k=search_top_k) results = await chunks_vdb.query(
query, top_k=search_top_k, query_embedding=query_embedding
)
if not results: if not results:
logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})") logger.info(f"Naive query: 0 chunks (chunk_top_k: {search_top_k})")
return [] return []
@@ -2291,6 +2295,10 @@ async def _build_query_context(
query_param: QueryParam, query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None, chunks_vdb: BaseVectorStorage = None,
): ):
if not query:
logger.warning("Query is empty, skipping context building")
return ""
logger.info(f"Process {os.getpid()} building query context...") logger.info(f"Process {os.getpid()} building query context...")
# Collect chunks from different sources separately # Collect chunks from different sources separately
@@ -2309,12 +2317,12 @@ async def _build_query_context(
# Track chunk sources and metadata for final logging # Track chunk sources and metadata for final logging
chunk_tracking = {} # chunk_id -> {source, frequency, order} chunk_tracking = {} # chunk_id -> {source, frequency, order}
# Pre-compute query embedding if vector similarity method is used # Pre-compute query embedding once for all vector operations
kg_chunk_pick_method = text_chunks_db.global_config.get( kg_chunk_pick_method = text_chunks_db.global_config.get(
"kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
) )
query_embedding = None query_embedding = None
if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
embedding_func_config = text_chunks_db.embedding_func embedding_func_config = text_chunks_db.embedding_func
if embedding_func_config and embedding_func_config.func: if embedding_func_config and embedding_func_config.func:
try: try:
@@ -2322,9 +2330,7 @@ async def _build_query_context(
query_embedding = query_embedding[ query_embedding = query_embedding[
0 0
] # Extract first embedding from batch result ] # Extract first embedding from batch result
logger.debug( logger.debug("Pre-computed query embedding for all vector operations")
"Pre-computed query embedding for vector similarity chunk selection"
)
except Exception as e: except Exception as e:
logger.warning(f"Failed to pre-compute query embedding: {e}") logger.warning(f"Failed to pre-compute query embedding: {e}")
query_embedding = None query_embedding = None
@@ -2368,6 +2374,7 @@ async def _build_query_context(
query, query,
chunks_vdb, chunks_vdb,
query_param, query_param,
query_embedding,
) )
# Track vector chunks with source metadata # Track vector chunks with source metadata
for i, chunk in enumerate(vector_chunks): for i, chunk in enumerate(vector_chunks):
@@ -3429,7 +3436,7 @@ async def naive_query(
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = global_config["tokenizer"]
chunks = await _get_vector_context(query, chunks_vdb, query_param) chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
if chunks is None or len(chunks) == 0: if chunks is None or len(chunks) == 0:
return PROMPTS["fail_response"] return PROMPTS["fail_response"]