fix chunk_top_k limiting
This commit is contained in:
parent
04a57445da
commit
c295d355a0
2 changed files with 20 additions and 4 deletions
|
|
@ -20,6 +20,7 @@ from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.rerank import custom_rerank, RerankModel
|
from lightrag.rerank import custom_rerank, RerankModel
|
||||||
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
||||||
from lightrag.utils import EmbeddingFunc, setup_logger
|
from lightrag.utils import EmbeddingFunc, setup_logger
|
||||||
|
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||||
|
|
||||||
# Set up your working directory
|
# Set up your working directory
|
||||||
WORKING_DIR = "./test_rerank"
|
WORKING_DIR = "./test_rerank"
|
||||||
|
|
@ -87,6 +88,9 @@ async def create_rag_with_rerank():
|
||||||
rerank_model_func=my_rerank_func,
|
rerank_model_func=my_rerank_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await rag.initialize_storages()
|
||||||
|
await initialize_pipeline_status()
|
||||||
|
|
||||||
return rag
|
return rag
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -120,6 +124,9 @@ async def create_rag_with_rerank_model():
|
||||||
rerank_model_func=rerank_model.rerank,
|
rerank_model_func=rerank_model.rerank,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await rag.initialize_storages()
|
||||||
|
await initialize_pipeline_status()
|
||||||
|
|
||||||
return rag
|
return rag
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2823,8 +2823,9 @@ async def apply_rerank_if_enabled(
|
||||||
documents=retrieved_docs,
|
documents=retrieved_docs,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
if reranked_docs and len(reranked_docs) > 0:
|
if reranked_docs and len(reranked_docs) > 0:
|
||||||
|
if len(reranked_docs) > top_k:
|
||||||
|
reranked_docs = reranked_docs[:top_k]
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
|
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
|
||||||
)
|
)
|
||||||
|
|
@ -2846,7 +2847,7 @@ async def process_chunks_unified(
|
||||||
source_type: str = "mixed",
|
source_type: str = "mixed",
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Unified processing for text chunks: deduplication, reranking, and token truncation.
|
Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Search query for reranking
|
query: Search query for reranking
|
||||||
|
|
@ -2874,7 +2875,15 @@ async def process_chunks_unified(
|
||||||
f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
|
f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Apply reranking if enabled and query is provided
|
# 2. Apply chunk_top_k limiting if specified
|
||||||
|
if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
|
||||||
|
if len(unique_chunks) > query_param.chunk_top_k:
|
||||||
|
unique_chunks = unique_chunks[: query_param.chunk_top_k]
|
||||||
|
logger.debug(
|
||||||
|
f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Apply reranking if enabled and query is provided
|
||||||
if global_config.get("enable_rerank", False) and query and unique_chunks:
|
if global_config.get("enable_rerank", False) and query and unique_chunks:
|
||||||
rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
|
rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
|
||||||
unique_chunks = await apply_rerank_if_enabled(
|
unique_chunks = await apply_rerank_if_enabled(
|
||||||
|
|
@ -2885,7 +2894,7 @@ async def process_chunks_unified(
|
||||||
)
|
)
|
||||||
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
||||||
|
|
||||||
# 3. Token-based final truncation
|
# 4. Token-based final truncation
|
||||||
tokenizer = global_config.get("tokenizer")
|
tokenizer = global_config.get("tokenizer")
|
||||||
if tokenizer and unique_chunks:
|
if tokenizer and unique_chunks:
|
||||||
original_count = len(unique_chunks)
|
original_count = len(unique_chunks)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue