diff --git a/lightrag/operate.py b/lightrag/operate.py
index 99351097..9fefdb70 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2112,6 +2112,26 @@ async def _build_query_context(
     # Track chunk sources and metadata for final logging
     chunk_tracking = {}  # chunk_id -> {source, frequency, order}
 
+    # Pre-compute query embedding if vector similarity method is used
+    kg_chunk_pick_method = text_chunks_db.global_config.get(
+        "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
+    )
+    query_embedding = None
+    if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
+        embedding_func_config = text_chunks_db.embedding_func
+        if embedding_func_config and embedding_func_config.func:
+            try:
+                query_embedding = await embedding_func_config.func([query])
+                query_embedding = query_embedding[
+                    0
+                ]  # Extract first embedding from batch result
+                logger.debug(
+                    "Pre-computed query embedding for vector similarity chunk selection"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to pre-compute query embedding: {e}")
+                query_embedding = None
+
     # Handle local and global modes
     if query_param.mode == "local":
         local_entities, local_relations = await _get_node_data(
@@ -2372,6 +2392,7 @@ async def _build_query_context(
             query,
             chunks_vdb,
             chunk_tracking=chunk_tracking,
+            query_embedding=query_embedding,
         )
 
         # Find deduplcicated chunks from edge
@@ -2385,6 +2406,7 @@ async def _build_query_context(
             query,
             chunks_vdb,
             chunk_tracking=chunk_tracking,
+            query_embedding=query_embedding,
         )
 
     # Round-robin merge chunks from different sources with deduplication by chunk_id
@@ -2719,6 +2741,7 @@ async def _find_related_text_unit_from_entities(
     query: str = None,
     chunks_vdb: BaseVectorStorage = None,
     chunk_tracking: dict = None,
+    query_embedding=None,
 ):
     """
     Find text chunks related to entities using configurable chunk selection method.
@@ -2814,6 +2837,7 @@ async def _find_related_text_unit_from_entities(
             num_of_chunks=num_of_chunks,
             entity_info=entities_with_chunks,
             embedding_func=actual_embedding_func,
+            query_embedding=query_embedding,
        )
 
         if selected_chunk_ids == []:
@@ -2971,6 +2995,7 @@ async def _find_related_text_unit_from_relations(
     query: str = None,
     chunks_vdb: BaseVectorStorage = None,
     chunk_tracking: dict = None,
+    query_embedding=None,
 ):
     """
     Find text chunks related to relationships using configurable chunk selection method.
@@ -3106,6 +3131,7 @@ async def _find_related_text_unit_from_relations(
             num_of_chunks=num_of_chunks,
             entity_info=relations_with_chunks,
             embedding_func=actual_embedding_func,
+            query_embedding=query_embedding,
         )
 
         if selected_chunk_ids == []:
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 979517b5..bec45f5f 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1774,6 +1774,7 @@ async def pick_by_vector_similarity(
     num_of_chunks: int,
     entity_info: list[dict[str, Any]],
     embedding_func: callable,
+    query_embedding=None,
 ) -> list[str]:
     """
     Vector similarity-based text chunk selection algorithm.
@@ -1818,11 +1819,19 @@ async def pick_by_vector_similarity(
     all_chunk_ids = list(all_chunk_ids)
 
     try:
-        # Get query embedding
-        query_embedding = await embedding_func([query])
-        query_embedding = query_embedding[
-            0
-        ]  # Extract first embedding from batch result
+        # Use pre-computed query embedding if provided, otherwise compute it
+        if query_embedding is None:
+            query_embedding = await embedding_func([query])
+            query_embedding = query_embedding[
+                0
+            ]  # Extract first embedding from batch result
+            logger.debug(
+                "Computed query embedding for vector similarity chunk selection"
+            )
+        else:
+            logger.debug(
+                "Using pre-computed query embedding for vector similarity chunk selection"
+            )
 
         # Get chunk embeddings from vector database
         chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids)
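For context, the change above follows a simple caching pattern: the query embedding is computed once at the top of the query-context build and threaded down to every chunk-selection call, with each callee falling back to computing it itself when no cached value is passed. Below is a minimal, self-contained sketch of that pattern; `embed` and `pick_chunks` are hypothetical stand-ins for illustration, not LightRAG's actual functions.

```python
# Sketch of the "compute once, pass down" query-embedding pattern.
# All names here are illustrative; only the control flow mirrors the diff.
from __future__ import annotations

import asyncio
import numpy as np


async def embed(texts: list[str]) -> np.ndarray:
    # Stand-in embedding function; a real one would call an embedding model.
    rng = np.random.default_rng(abs(hash(tuple(texts))) % (2**32))
    return rng.random((len(texts), 8))


async def pick_chunks(
    query: str,
    chunk_vectors: dict[str, np.ndarray],
    num_of_chunks: int,
    query_embedding: np.ndarray | None = None,
) -> list[str]:
    # Reuse the pre-computed embedding when provided; otherwise compute it here,
    # mirroring the fallback branch added to pick_by_vector_similarity.
    if query_embedding is None:
        query_embedding = (await embed([query]))[0]
    scores = {
        chunk_id: float(
            np.dot(query_embedding, vec)
            / (np.linalg.norm(query_embedding) * np.linalg.norm(vec))
        )
        for chunk_id, vec in chunk_vectors.items()
    }
    # Return the top-N chunk ids by cosine similarity.
    return sorted(scores, key=scores.get, reverse=True)[:num_of_chunks]


async def main() -> None:
    chunks = {f"chunk-{i}": (await embed([f"text {i}"]))[0] for i in range(5)}
    query = "example query"
    # Compute the query embedding once, then reuse it for both the
    # entity-derived and relation-derived chunk selections.
    query_embedding = (await embed([query]))[0]
    from_entities = await pick_chunks(query, chunks, 2, query_embedding)
    from_relations = await pick_chunks(query, chunks, 2, query_embedding)
    print(from_entities, from_relations)


asyncio.run(main())
```

The benefit is the same as in the diff: when both the entity path and the relation path use vector-based chunk picking, the embedding call for the query happens once per request instead of once per caller.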