Merge pull request #1848 from HKUDS/context-builder
Refactor: optimize the context builder to improve query performance
commit 733adb420d
6 changed files with 634 additions and 550 deletions
@@ -48,7 +48,7 @@ OLLAMA_EMULATING_MODEL_TAG=latest
 ########################
 ### Query Configuration
 ########################
-# LLM responde cache for query (Not valid for streaming response
+# LLM responde cache for query (Not valid for streaming response)
 ENABLE_LLM_CACHE=true
 # HISTORY_TURNS=0
 # COSINE_THRESHOLD=0.2

@@ -62,8 +62,8 @@ ENABLE_LLM_CACHE=true
 # MAX_RELATION_TOKENS=10000
 ### control the maximum tokens send to LLM (include entities, raltions and chunks)
 # MAX_TOTAL_TOKENS=32000
-### maxumium related chunks grab from single entity or relations
-# RELATED_CHUNK_NUMBER=10
+### maximum number of related chunks per source entity or relation (higher values increase re-ranking time)
+# RELATED_CHUNK_NUMBER=5

 ### Reranker configuration (Set ENABLE_RERANK to true in reranking model is configed)
 ENABLE_RERANK=False
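For context, a minimal sketch of how these query settings might be resolved at startup. The environment variable names come from the hunk above and the fallback values match the `DEFAULT_*` constants changed later in this diff; the `_env_int`/`_env_bool` helpers are simplified stand-ins for the project's own `get_env_value` helper, not part of the PR:

```python
import os

# Hypothetical stand-ins for LightRAG's get_env_value helper: read a query
# setting from the environment, falling back to the defaults shown in this PR.
def _env_int(name: str, default: int) -> int:
    raw = os.environ.get(name)
    return int(raw) if raw is not None else default

def _env_bool(name: str, default: bool) -> bool:
    raw = os.environ.get(name)
    return raw.lower() in ("1", "true", "yes") if raw is not None else default

max_total_tokens = _env_int("MAX_TOTAL_TOKENS", 32000)      # DEFAULT_MAX_TOTAL_TOKENS
related_chunk_number = _env_int("RELATED_CHUNK_NUMBER", 5)  # new default in this PR (was 10)
enable_rerank = _env_bool("ENABLE_RERANK", True)            # DEFAULT_ENABLE_RERANK

print(max_total_tokens, related_chunk_number, enable_rerank)
```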
@@ -1 +1 @@
-__api_version__ = "0188"
+__api_version__ = "0189"
@@ -21,7 +21,7 @@ DEFAULT_MAX_TOTAL_TOKENS = 32000
 DEFAULT_HISTORY_TURNS = 0
 DEFAULT_ENABLE_RERANK = True
 DEFAULT_COSINE_THRESHOLD = 0.2
-DEFAULT_RELATED_CHUNK_NUMBER = 10
+DEFAULT_RELATED_CHUNK_NUMBER = 5

 # Separator for graph fields
 GRAPH_FIELD_SEP = "<SEP>"
@@ -536,7 +536,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
 # Load the collection if it's not already loaded
 # In Milvus, collections need to be loaded before they can be searched
 self._client.load_collection(self.namespace)
-logger.debug(f"Collection {self.namespace} loaded successfully")
+# logger.debug(f"Collection {self.namespace} loaded successfully")

 except Exception as e:
 logger.error(f"Failed to load collection {self.namespace}: {e}")
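The kept comment is the operative point here: Milvus collections must be loaded into memory before they can be searched. A rough, hedged sketch of that pattern with the standalone `pymilvus` client; the URI, collection name, and vector values are placeholders, not taken from this PR:

```python
# Hedged sketch: load a Milvus collection before searching it.
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")

client.load_collection("lightrag_chunks")  # searching an unloaded collection fails

query_vec = [0.1] * 768  # placeholder embedding
results = client.search(
    collection_name="lightrag_chunks",
    data=[query_vec],
    limit=5,
    output_fields=["content"],
)
print(results[0])
```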
File diff suppressed because it is too large
@@ -55,7 +55,7 @@ def get_env_value(

 # Use TYPE_CHECKING to avoid circular imports
 if TYPE_CHECKING:
-    from lightrag.base import BaseKVStorage
+    from lightrag.base import BaseKVStorage, QueryParam

 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
@@ -85,8 +85,10 @@ def verbose_debug(msg: str, *args, **kwargs):
 formatted_msg = msg
 # Then truncate the formatted message
 truncated_msg = (
-    formatted_msg[:100] + "..." if len(formatted_msg) > 100 else formatted_msg
+    formatted_msg[:150] + "..." if len(formatted_msg) > 150 else formatted_msg
 )
+# Remove consecutive newlines
+truncated_msg = re.sub(r"\n+", "\n", truncated_msg)
 logger.debug(truncated_msg, **kwargs)
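The change above widens the `verbose_debug` preview from 100 to 150 characters and collapses consecutive newlines. A standalone reproduction of just that formatting step for illustration (the function name here is invented, not part of the PR):

```python
import re

def preview_for_debug(formatted_msg: str) -> str:
    """Illustrative copy of the truncation logic: cap at 150 chars, collapse blank lines."""
    truncated_msg = (
        formatted_msg[:150] + "..." if len(formatted_msg) > 150 else formatted_msg
    )
    # Remove consecutive newlines so multi-line payloads stay compact in debug logs
    return re.sub(r"\n+", "\n", truncated_msg)

print(preview_for_debug("line one\n\n\nline two\n" + "x" * 200))
```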
@@ -777,39 +779,6 @@ def truncate_list_by_token_size(
 return list_data


-def process_combine_contexts(*context_lists):
-    """
-    Combine multiple context lists and remove duplicate content
-
-    Args:
-        *context_lists: Any number of context lists
-
-    Returns:
-        Combined context list with duplicates removed
-    """
-    seen_content = {}
-    combined_data = []
-
-    # Iterate through all input context lists
-    for context_list in context_lists:
-        if not context_list:  # Skip empty lists
-            continue
-        for item in context_list:
-            content_dict = {
-                k: v for k, v in item.items() if k != "id" and k != "created_at"
-            }
-            content_key = tuple(sorted(content_dict.items()))
-            if content_key not in seen_content:
-                seen_content[content_key] = item
-                combined_data.append(item)
-
-    # Reassign IDs
-    for i, item in enumerate(combined_data):
-        item["id"] = str(i + 1)
-
-    return combined_data
-
-
 def cosine_similarity(v1, v2):
     """Calculate cosine similarity between two vectors"""
     dot_product = np.dot(v1, v2)
@@ -1673,6 +1642,86 @@ def check_storage_env_vars(storage_name: str) -> None:
 )


+def linear_gradient_weighted_polling(
+    entities_or_relations: list[dict],
+    max_related_chunks: int,
+    min_related_chunks: int = 1,
+) -> list[str]:
+    """
+    Linear gradient weighted polling algorithm for text chunk selection.
+
+    This algorithm ensures that entities/relations with higher importance get more text chunks,
+    forming a linear decreasing allocation pattern.
+
+    Args:
+        entities_or_relations: List of entities or relations sorted by importance (high to low)
+        max_related_chunks: Expected number of text chunks for the highest importance entity/relation
+        min_related_chunks: Expected number of text chunks for the lowest importance entity/relation
+
+    Returns:
+        List of selected text chunk IDs
+    """
+    if not entities_or_relations:
+        return []
+
+    n = len(entities_or_relations)
+    if n == 1:
+        # Only one entity/relation, return its first max_related_chunks text chunks
+        entity_chunks = entities_or_relations[0].get("sorted_chunks", [])
+        return entity_chunks[:max_related_chunks]
+
+    # Calculate expected text chunk count for each position (linear decrease)
+    expected_counts = []
+    for i in range(n):
+        # Linear interpolation: from max_related_chunks to min_related_chunks
+        ratio = i / (n - 1) if n > 1 else 0
+        expected = max_related_chunks - ratio * (
+            max_related_chunks - min_related_chunks
+        )
+        expected_counts.append(int(round(expected)))
+
+    # First round allocation: allocate by expected values
+    selected_chunks = []
+    used_counts = []  # Track number of chunks used by each entity
+    total_remaining = 0  # Accumulate remaining quotas
+
+    for i, entity_rel in enumerate(entities_or_relations):
+        entity_chunks = entity_rel.get("sorted_chunks", [])
+        expected = expected_counts[i]
+
+        # Actual allocatable count
+        actual = min(expected, len(entity_chunks))
+        selected_chunks.extend(entity_chunks[:actual])
+        used_counts.append(actual)
+
+        # Accumulate remaining quota
+        remaining = expected - actual
+        if remaining > 0:
+            total_remaining += remaining
+
+    # Second round allocation: multi-round scanning to allocate remaining quotas
+    for _ in range(total_remaining):
+        allocated = False
+
+        # Scan entities one by one, allocate one chunk when finding unused chunks
+        for i, entity_rel in enumerate(entities_or_relations):
+            entity_chunks = entity_rel.get("sorted_chunks", [])
+
+            # Check if there are still unused chunks
+            if used_counts[i] < len(entity_chunks):
+                # Allocate one chunk
+                selected_chunks.append(entity_chunks[used_counts[i]])
+                used_counts[i] += 1
+                allocated = True
+                break
+
+        # If no chunks were allocated in this round, all entities are exhausted
+        if not allocated:
+            break
+
+    return selected_chunks
+
+
 class TokenTracker:
     """Track token usage for LLM calls."""
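To make the allocation pattern concrete, a hedged usage sketch (assuming the helper lands in `lightrag.utils`, as the surrounding hunks suggest; the entity names and chunk IDs are invented): with three entities, `max_related_chunks=5`, and `min_related_chunks=1`, the per-position quotas decrease linearly as 5, 3, 1, and any quota an entity cannot fill is re-offered to the remaining entities in a second pass.

```python
# Hedged usage sketch; entity/chunk data is fabricated for illustration.
from lightrag.utils import linear_gradient_weighted_polling

entities = [
    {"entity_name": "A", "sorted_chunks": ["a1", "a2", "a3", "a4", "a5", "a6"]},  # most important
    {"entity_name": "B", "sorted_chunks": ["b1", "b2"]},
    {"entity_name": "C", "sorted_chunks": ["c1", "c2", "c3"]},                    # least important
]

chunk_ids = linear_gradient_weighted_polling(entities, max_related_chunks=5, min_related_chunks=1)
# Expected quotas per position: 5, 3, 1. Entity B only has 2 chunks, so its
# leftover quota (1) is re-offered front-to-back in the second pass and lands
# on A's next unused chunk (a6).
print(chunk_ids)
```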
@@ -1728,3 +1777,127 @@ class TokenTracker:
 f"Completion tokens: {usage['completion_tokens']}, "
 f"Total tokens: {usage['total_tokens']}"
 )
+
+
+async def apply_rerank_if_enabled(
+    query: str,
+    retrieved_docs: list[dict],
+    global_config: dict,
+    enable_rerank: bool = True,
+    top_n: int = None,
+) -> list[dict]:
+    """
+    Apply reranking to retrieved documents if rerank is enabled.
+
+    Args:
+        query: The search query
+        retrieved_docs: List of retrieved documents
+        global_config: Global configuration containing rerank settings
+        enable_rerank: Whether to enable reranking from query parameter
+        top_n: Number of top documents to return after reranking
+
+    Returns:
+        Reranked documents if rerank is enabled, otherwise original documents
+    """
+    if not enable_rerank or not retrieved_docs:
+        return retrieved_docs
+
+    rerank_func = global_config.get("rerank_model_func")
+    if not rerank_func:
+        logger.warning(
+            "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters."
+        )
+        return retrieved_docs
+
+    try:
+        # Apply reranking - let rerank_model_func handle top_k internally
+        reranked_docs = await rerank_func(
+            query=query,
+            documents=retrieved_docs,
+            top_n=top_n,
+        )
+        if reranked_docs and len(reranked_docs) > 0:
+            if len(reranked_docs) > top_n:
+                reranked_docs = reranked_docs[:top_n]
+            logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks")
+            return reranked_docs
+        else:
+            logger.warning("Rerank returned empty results, using original chunks")
+            return retrieved_docs
+
+    except Exception as e:
+        logger.error(f"Error during reranking: {e}, using original chunks")
+        return retrieved_docs
+
+
+async def process_chunks_unified(
+    query: str,
+    unique_chunks: list[dict],
+    query_param: "QueryParam",
+    global_config: dict,
+    source_type: str = "mixed",
+    chunk_token_limit: int = None,  # Add parameter for dynamic token limit
+) -> list[dict]:
+    """
+    Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.
+
+    Args:
+        query: Search query for reranking
+        chunks: List of text chunks to process
+        query_param: Query parameters containing configuration
+        global_config: Global configuration dictionary
+        source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
+        chunk_token_limit: Dynamic token limit for chunks (if None, uses default)
+
+    Returns:
+        Processed and filtered list of text chunks
+    """
+    if not unique_chunks:
+        return []
+
+    origin_count = len(unique_chunks)
+
+    # 1. Apply reranking if enabled and query is provided
+    if query_param.enable_rerank and query and unique_chunks:
+        rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
+        unique_chunks = await apply_rerank_if_enabled(
+            query=query,
+            retrieved_docs=unique_chunks,
+            global_config=global_config,
+            enable_rerank=query_param.enable_rerank,
+            top_n=rerank_top_k,
+        )
+
+    # 2. Apply chunk_top_k limiting if specified
+    if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
+        if len(unique_chunks) > query_param.chunk_top_k:
+            unique_chunks = unique_chunks[: query_param.chunk_top_k]
+        logger.debug(
+            f"Kept chunk_top-k: {len(unique_chunks)} chunks (deduplicated original: {origin_count})"
+        )
+
+    # 3. Token-based final truncation
+    tokenizer = global_config.get("tokenizer")
+    if tokenizer and unique_chunks:
+        # Set default chunk_token_limit if not provided
+        if chunk_token_limit is None:
+            # Get default from query_param or global_config
+            chunk_token_limit = getattr(
+                query_param,
+                "max_total_tokens",
+                global_config.get("MAX_TOTAL_TOKENS", 32000),
+            )
+
+        original_count = len(unique_chunks)
+        unique_chunks = truncate_list_by_token_size(
+            unique_chunks,
+            key=lambda x: x.get("content", ""),
+            max_token_size=chunk_token_limit,
+            tokenizer=tokenizer,
+        )
+        logger.debug(
+            f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
+            f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
+        )
+
+    return unique_chunks
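A hedged end-to-end sketch of how the rerank helper is meant to be driven (assuming it is importable from `lightrag.utils`; the dummy rerank function and chunk data are invented stand-ins, while in the repo the real rerank model and `QueryParam` come from the LightRAG instance):

```python
# Hedged usage sketch; the rerank function and documents are stand-ins.
import asyncio
from lightrag.utils import apply_rerank_if_enabled

async def dummy_rerank(query: str, documents: list[dict], top_n: int = None) -> list[dict]:
    # Pretend reranker: sort by how often the query term appears in the content.
    ranked = sorted(documents, key=lambda d: d["content"].count(query), reverse=True)
    return ranked[:top_n] if top_n else ranked

chunks = [
    {"content": "alpha beta"},
    {"content": "query query query"},
    {"content": "query beta"},
]

reranked = asyncio.run(
    apply_rerank_if_enabled(
        query="query",
        retrieved_docs=chunks,
        global_config={"rerank_model_func": dummy_rerank},
        enable_rerank=True,
        top_n=2,
    )
)
print([c["content"] for c in reranked])  # most query-relevant chunks first
```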