Refactor: move reranking utilities from operate.py to utils.py
• Move apply_rerank_if_enabled to utils • Move process_chunks_unified to utils
This commit is contained in:
parent
2c940f0728
commit
3075691f72
2 changed files with 122 additions and 126 deletions
|
|
@ -27,6 +27,7 @@ from .utils import (
|
||||||
update_chunk_cache_list,
|
update_chunk_cache_list,
|
||||||
remove_think_tags,
|
remove_think_tags,
|
||||||
linear_gradient_weighted_polling,
|
linear_gradient_weighted_polling,
|
||||||
|
process_chunks_unified,
|
||||||
)
|
)
|
||||||
from .base import (
|
from .base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
|
|
@ -3215,128 +3216,3 @@ async def query_with_keywords(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown mode {param.mode}")
|
raise ValueError(f"Unknown mode {param.mode}")
|
||||||
|
|
||||||
|
|
||||||
async def apply_rerank_if_enabled(
    query: str,
    retrieved_docs: list[dict],
    global_config: dict,
    enable_rerank: bool = True,
    top_n: int = None,
) -> list[dict]:
    """
    Apply reranking to retrieved documents if rerank is enabled.

    Reranking is best-effort: whenever it is disabled, not configured,
    returns nothing, or raises, the original ``retrieved_docs`` object is
    returned unchanged.

    Args:
        query: The search query
        retrieved_docs: List of retrieved documents
        global_config: Global configuration; the rerank callable is read
            from its ``"rerank_model_func"`` key and is awaited with the
            keyword arguments ``query``, ``documents`` and ``top_n``
        enable_rerank: Whether to enable reranking from query parameter
        top_n: Number of top documents to return after reranking; ``None``
            (the default) means the reranked list is not cut here

    Returns:
        Reranked documents if rerank is enabled, otherwise original documents
    """
    if not enable_rerank or not retrieved_docs:
        return retrieved_docs

    rerank_func = global_config.get("rerank_model_func")
    if not rerank_func:
        logger.warning(
            "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters."
        )
        return retrieved_docs

    try:
        # Apply reranking - let rerank_model_func handle top_k internally
        reranked_docs = await rerank_func(
            query=query,
            documents=retrieved_docs,
            top_n=top_n,
        )
        if reranked_docs:
            # top_n may be None (its default). Comparing len() with None
            # raises TypeError, which the handler below would swallow and
            # thereby throw away a perfectly good rerank result.
            if top_n is not None and len(reranked_docs) > top_n:
                reranked_docs = reranked_docs[:top_n]
            logger.info(
                f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
            )
            return reranked_docs
        else:
            logger.warning("Rerank returned empty results, using original documents")
            return retrieved_docs

    except Exception as e:
        # Deliberate best-effort: a failing reranker must not fail the query.
        logger.error(f"Error during reranking: {e}, using original documents")
        return retrieved_docs
|
|
||||||
|
|
||||||
|
|
||||||
async def process_chunks_unified(
    query: str,
    unique_chunks: list[dict],
    query_param: "QueryParam",
    global_config: dict,
    source_type: str = "mixed",
    chunk_token_limit: int = None,  # Add parameter for dynamic token limit
) -> list[dict]:
    """
    Unified processing for text chunks: reranking, chunk_top_k limiting, and token truncation.

    The steps run in that order. No deduplication is performed here; the
    chunks are used as given.

    Args:
        query: Search query for reranking (reranking is skipped when empty)
        unique_chunks: List of text chunks to process
        query_param: Query parameters containing configuration; this function
            reads ``enable_rerank``, ``chunk_top_k`` and, if present,
            ``max_total_tokens``
        global_config: Global configuration dictionary; ``"tokenizer"`` and
            ``"MAX_TOTAL_TOKENS"`` are read here and the dictionary is passed
            on to the rerank step
        source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
        chunk_token_limit: Dynamic token limit for chunks (if None, uses default)

    Returns:
        Processed and filtered list of text chunks
    """
    if not unique_chunks:
        return []

    # 1. Apply reranking if enabled and query is provided
    if query_param.enable_rerank and query:
        # A missing/zero chunk_top_k means "rerank everything we have".
        rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
        unique_chunks = await apply_rerank_if_enabled(
            query=query,
            retrieved_docs=unique_chunks,
            global_config=global_config,
            enable_rerank=query_param.enable_rerank,
            top_n=rerank_top_k,
        )
        logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")

    # 2. Apply chunk_top_k limiting if specified
    if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
        if len(unique_chunks) > query_param.chunk_top_k:
            unique_chunks = unique_chunks[: query_param.chunk_top_k]
        logger.debug(
            f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
        )

    # 3. Token-based final truncation
    tokenizer = global_config.get("tokenizer")
    if tokenizer and unique_chunks:
        # Set default chunk_token_limit if not provided
        if chunk_token_limit is None:
            # Prefer the per-query value; an attribute that exists but is
            # None must not be passed on as the token budget, so fall back
            # to the global default in that case as well.
            chunk_token_limit = getattr(query_param, "max_total_tokens", None)
            if chunk_token_limit is None:
                chunk_token_limit = global_config.get("MAX_TOTAL_TOKENS", 32000)

        original_count = len(unique_chunks)
        unique_chunks = truncate_list_by_token_size(
            unique_chunks,
            key=lambda x: x.get("content", ""),
            max_token_size=chunk_token_limit,
            tokenizer=tokenizer,
        )
        logger.debug(
            f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
            f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
        )

    return unique_chunks
|
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ def get_env_value(
|
||||||
|
|
||||||
# Use TYPE_CHECKING to avoid circular imports
|
# Use TYPE_CHECKING to avoid circular imports
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from lightrag.base import BaseKVStorage
|
from lightrag.base import BaseKVStorage, QueryParam
|
||||||
|
|
||||||
# use the .env that is inside the current folder
|
# use the .env that is inside the current folder
|
||||||
# allows to use different .env file for each lightrag instance
|
# allows to use different .env file for each lightrag instance
|
||||||
|
|
@ -1777,3 +1777,123 @@ class TokenTracker:
|
||||||
f"Completion tokens: {usage['completion_tokens']}, "
|
f"Completion tokens: {usage['completion_tokens']}, "
|
||||||
f"Total tokens: {usage['total_tokens']}"
|
f"Total tokens: {usage['total_tokens']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def apply_rerank_if_enabled(
    query: str,
    retrieved_docs: list[dict],
    global_config: dict,
    enable_rerank: bool = True,
    top_n: int = None,
) -> list[dict]:
    """
    Apply reranking to retrieved documents if rerank is enabled.

    Reranking is best-effort: whenever it is disabled, not configured,
    returns nothing, or raises, the original ``retrieved_docs`` object is
    returned unchanged.

    Args:
        query: The search query
        retrieved_docs: List of retrieved documents
        global_config: Global configuration; the rerank callable is read
            from its ``"rerank_model_func"`` key and is awaited with the
            keyword arguments ``query``, ``documents`` and ``top_n``
        enable_rerank: Whether to enable reranking from query parameter
        top_n: Number of top documents to return after reranking; ``None``
            (the default) means the reranked list is not cut here

    Returns:
        Reranked documents if rerank is enabled, otherwise original documents
    """
    if not enable_rerank or not retrieved_docs:
        return retrieved_docs

    rerank_func = global_config.get("rerank_model_func")
    if not rerank_func:
        logger.warning(
            "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters."
        )
        return retrieved_docs

    try:
        # Apply reranking - let rerank_model_func handle top_k internally
        reranked_docs = await rerank_func(
            query=query,
            documents=retrieved_docs,
            top_n=top_n,
        )
        if reranked_docs:
            # top_n may be None (its default). Comparing len() with None
            # raises TypeError, which the handler below would swallow and
            # thereby throw away a perfectly good rerank result.
            if top_n is not None and len(reranked_docs) > top_n:
                reranked_docs = reranked_docs[:top_n]
            logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks")
            return reranked_docs
        else:
            logger.warning("Rerank returned empty results, using original chunks")
            return retrieved_docs

    except Exception as e:
        # Deliberate best-effort: a failing reranker must not fail the query.
        logger.error(f"Error during reranking: {e}, using original chunks")
        return retrieved_docs
|
||||||
|
|
||||||
|
|
||||||
|
async def process_chunks_unified(
    query: str,
    unique_chunks: list[dict],
    query_param: "QueryParam",
    global_config: dict,
    source_type: str = "mixed",
    chunk_token_limit: int = None,  # Add parameter for dynamic token limit
) -> list[dict]:
    """
    Unified processing for text chunks: reranking, chunk_top_k limiting, and token truncation.

    The steps run in that order. No deduplication is performed here; the
    chunks are used as given.

    Args:
        query: Search query for reranking (reranking is skipped when empty)
        unique_chunks: List of text chunks to process
        query_param: Query parameters containing configuration; this function
            reads ``enable_rerank``, ``chunk_top_k`` and, if present,
            ``max_total_tokens``
        global_config: Global configuration dictionary; ``"tokenizer"`` and
            ``"MAX_TOTAL_TOKENS"`` are read here and the dictionary is passed
            on to the rerank step
        source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
        chunk_token_limit: Dynamic token limit for chunks (if None, uses default)

    Returns:
        Processed and filtered list of text chunks
    """
    if not unique_chunks:
        return []

    # 1. Apply reranking if enabled and query is provided
    if query_param.enable_rerank and query:
        # A missing/zero chunk_top_k means "rerank everything we have".
        rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
        unique_chunks = await apply_rerank_if_enabled(
            query=query,
            retrieved_docs=unique_chunks,
            global_config=global_config,
            enable_rerank=query_param.enable_rerank,
            top_n=rerank_top_k,
        )

    # 2. Apply chunk_top_k limiting if specified
    if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
        if len(unique_chunks) > query_param.chunk_top_k:
            unique_chunks = unique_chunks[: query_param.chunk_top_k]
        logger.info(f"Kept chunk_top-k: {len(unique_chunks)} chunks")

    # 3. Token-based final truncation
    tokenizer = global_config.get("tokenizer")
    if tokenizer and unique_chunks:
        # Set default chunk_token_limit if not provided
        if chunk_token_limit is None:
            # Prefer the per-query value; an attribute that exists but is
            # None must not be passed on as the token budget, so fall back
            # to the global default in that case as well.
            chunk_token_limit = getattr(query_param, "max_total_tokens", None)
            if chunk_token_limit is None:
                chunk_token_limit = global_config.get("MAX_TOTAL_TOKENS", 32000)

        original_count = len(unique_chunks)
        unique_chunks = truncate_list_by_token_size(
            unique_chunks,
            key=lambda x: x.get("content", ""),
            max_token_size=chunk_token_limit,
            tokenizer=tokenizer,
        )
        logger.debug(
            f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
            f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
        )

    return unique_chunks
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue