optimize: avoid duplicate embedding calls in _build_query_context

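Reduces API costs and improves query performance while maintaining backward compatibility. The query embedding is now computed once in _build_query_context (only when kg_chunk_pick_method is "VECTOR") and handed down through a new optional query_embedding=None parameter; pick_by_vector_similarity still computes the embedding itself when none is supplied.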
This commit is contained in:
yangdx 2025-08-21 16:49:24 +08:00
parent 10bcf1479f
commit b5c230abdd
2 changed files with 40 additions and 5 deletions

View file

@ -2112,6 +2112,26 @@ async def _build_query_context(
# Track chunk sources and metadata for final logging
chunk_tracking = {} # chunk_id -> {source, frequency, order}
# Pre-compute query embedding if vector similarity method is used
kg_chunk_pick_method = text_chunks_db.global_config.get(
"kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
)
query_embedding = None
if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
embedding_func_config = text_chunks_db.embedding_func
if embedding_func_config and embedding_func_config.func:
try:
query_embedding = await embedding_func_config.func([query])
query_embedding = query_embedding[
0
] # Extract first embedding from batch result
logger.debug(
"Pre-computed query embedding for vector similarity chunk selection"
)
except Exception as e:
logger.warning(f"Failed to pre-compute query embedding: {e}")
query_embedding = None
# Handle local and global modes
if query_param.mode == "local":
local_entities, local_relations = await _get_node_data(
@ -2372,6 +2392,7 @@ async def _build_query_context(
query,
chunks_vdb,
chunk_tracking=chunk_tracking,
query_embedding=query_embedding,
)
# Find deduplicated chunks from edge
@ -2385,6 +2406,7 @@ async def _build_query_context(
query,
chunks_vdb,
chunk_tracking=chunk_tracking,
query_embedding=query_embedding,
)
# Round-robin merge chunks from different sources with deduplication by chunk_id
@ -2719,6 +2741,7 @@ async def _find_related_text_unit_from_entities(
query: str = None,
chunks_vdb: BaseVectorStorage = None,
chunk_tracking: dict = None,
query_embedding=None,
):
"""
Find text chunks related to entities using configurable chunk selection method.
@ -2814,6 +2837,7 @@ async def _find_related_text_unit_from_entities(
num_of_chunks=num_of_chunks,
entity_info=entities_with_chunks,
embedding_func=actual_embedding_func,
query_embedding=query_embedding,
)
if selected_chunk_ids == []:
@ -2971,6 +2995,7 @@ async def _find_related_text_unit_from_relations(
query: str = None,
chunks_vdb: BaseVectorStorage = None,
chunk_tracking: dict = None,
query_embedding=None,
):
"""
Find text chunks related to relationships using configurable chunk selection method.
@ -3106,6 +3131,7 @@ async def _find_related_text_unit_from_relations(
num_of_chunks=num_of_chunks,
entity_info=relations_with_chunks,
embedding_func=actual_embedding_func,
query_embedding=query_embedding,
)
if selected_chunk_ids == []:

View file

@ -1774,6 +1774,7 @@ async def pick_by_vector_similarity(
num_of_chunks: int,
entity_info: list[dict[str, Any]],
embedding_func: callable,
query_embedding=None,
) -> list[str]:
"""
Vector similarity-based text chunk selection algorithm.
@ -1818,11 +1819,19 @@ async def pick_by_vector_similarity(
all_chunk_ids = list(all_chunk_ids)
try:
# Get query embedding
query_embedding = await embedding_func([query])
query_embedding = query_embedding[
0
] # Extract first embedding from batch result
# Use pre-computed query embedding if provided, otherwise compute it
if query_embedding is None:
query_embedding = await embedding_func([query])
query_embedding = query_embedding[
0
] # Extract first embedding from batch result
logger.debug(
"Computed query embedding for vector similarity chunk selection"
)
else:
logger.debug(
"Using pre-computed query embedding for vector similarity chunk selection"
)
# Get chunk embeddings from vector database
chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids)