From c295d355a0525871971c8e19bc9fb75e6a50a5d6 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Tue, 8 Jul 2025 15:05:30 +0800
Subject: [PATCH] fix chunk_top_k limiting

---
 examples/rerank_example.py |  7 +++++++
 lightrag/operate.py        | 17 +++++++++++++----
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/examples/rerank_example.py b/examples/rerank_example.py
index 74ec85bc..e0e361a5 100644
--- a/examples/rerank_example.py
+++ b/examples/rerank_example.py
@@ -20,6 +20,7 @@ from lightrag import LightRAG, QueryParam
 from lightrag.rerank import custom_rerank, RerankModel
 from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc, setup_logger
+from lightrag.kg.shared_storage import initialize_pipeline_status

 # Set up your working directory
 WORKING_DIR = "./test_rerank"
@@ -87,6 +88,9 @@ async def create_rag_with_rerank():
         rerank_model_func=my_rerank_func,
     )

+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
     return rag


@@ -120,6 +124,9 @@ async def create_rag_with_rerank_model():
         rerank_model_func=rerank_model.rerank,
     )

+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
     return rag


diff --git a/lightrag/operate.py b/lightrag/operate.py
index f9f53285..05fef78e 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2823,8 +2823,9 @@ async def apply_rerank_if_enabled(
             documents=retrieved_docs,
             top_k=top_k,
         )

-        if reranked_docs and len(reranked_docs) > 0:
+        if len(reranked_docs) > top_k:
+            reranked_docs = reranked_docs[:top_k]
             logger.info(
                 f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
             )
@@ -2846,7 +2847,7 @@ async def process_chunks_unified(
     source_type: str = "mixed",
 ) -> list[dict]:
     """
-    Unified processing for text chunks: deduplication, reranking, and token truncation.
+    Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.

     Args:
         query: Search query for reranking
@@ -2874,7 +2875,15 @@ async def process_chunks_unified(
         f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
     )

-    # 2. Apply reranking if enabled and query is provided
+    # 2. Apply chunk_top_k limiting if specified
+    if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
+        if len(unique_chunks) > query_param.chunk_top_k:
+            unique_chunks = unique_chunks[: query_param.chunk_top_k]
+        logger.debug(
+            f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
+        )
+
+    # 3. Apply reranking if enabled and query is provided
     if global_config.get("enable_rerank", False) and query and unique_chunks:
         rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
         unique_chunks = await apply_rerank_if_enabled(
@@ -2885,7 +2893,7 @@ async def process_chunks_unified(
         )
         logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")

-    # 3. Token-based final truncation
+    # 4. Token-based final truncation
     tokenizer = global_config.get("tokenizer")
     if tokenizer and unique_chunks:
         original_count = len(unique_chunks)
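
A minimal usage sketch of the chunk_top_k limiting exercised by this patch, assuming the patched examples/rerank_example.py is importable from the repository root and its API keys are configured; the mode and question text here are illustrative, not prescribed by the patch:

import asyncio

from lightrag import QueryParam

# create_rag_with_rerank() is defined in the patched example above; this
# import path assumes the repository root is on PYTHONPATH.
from examples.rerank_example import create_rag_with_rerank


async def main():
    rag = await create_rag_with_rerank()

    # chunk_top_k caps the deduplicated chunk list in step 2 of
    # process_chunks_unified(); with rerank enabled, apply_rerank_if_enabled()
    # now also trims the reranked list back down to its top_k.
    param = QueryParam(
        mode="hybrid",   # any mode that retrieves text chunks
        chunk_top_k=5,   # keep at most 5 chunks before token truncation
    )
    print(await rag.aquery("What are the main findings?", param=param))


if __name__ == "__main__":
    asyncio.run(main())

Before this patch, the retrieved chunk list was only bounded by token truncation, so chunk_top_k had no effect; with the patch, the cap is applied right after deduplication and again after reranking.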