update chunks truncation method

This commit is contained in:
zrguo 2025-07-08 13:31:05 +08:00
parent f5c80d7cde
commit 04a57445da
5 changed files with 211 additions and 180 deletions

View file

@@ -294,6 +294,16 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60")) top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
"""Maximum number of tokens allowed for each retrieved text chunk.""" """Maximum number of tokens allowed for each retrieved text chunk."""

View file

@@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d
python examples/lightrag_openai_demo.py python examples/lightrag_openai_demo.py
``` ```
For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample codes LLM and embedding configurations accordingly. For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly.
**Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory. **Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory.
@@ -300,6 +300,16 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60")) top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
"""Maximum number of tokens allowed for each retrieved text chunk.""" """Maximum number of tokens allowed for each retrieved text chunk."""

View file

@@ -46,7 +46,9 @@ OLLAMA_EMULATING_MODEL_TAG=latest
# HISTORY_TURNS=3 # HISTORY_TURNS=3
# COSINE_THRESHOLD=0.2 # COSINE_THRESHOLD=0.2
# TOP_K=60 # TOP_K=60
# MAX_TOKEN_TEXT_CHUNK=4000 # CHUNK_TOP_K=5
# CHUNK_RERANK_TOP_K=5
# MAX_TOKEN_TEXT_CHUNK=6000
# MAX_TOKEN_RELATION_DESC=4000 # MAX_TOKEN_RELATION_DESC=4000
# MAX_TOKEN_ENTITY_DESC=4000 # MAX_TOKEN_ENTITY_DESC=4000

View file

@@ -60,7 +60,17 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60")) top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000"))
"""Maximum number of tokens allowed for each retrieved text chunk.""" """Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_global_context: int = int( max_token_for_global_context: int = int(
@@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC):
False: if the cache drop failed, or the cache mode is not supported False: if the cache drop failed, or the cache mode is not supported
""" """
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
# """Delete specific cache records from storage by chunk IDs
# Importance notes for in-memory storage:
# 1. Changes will be persisted to disk during the next index_done_callback
# 2. update flags to notify other processes that data persistence is needed
# Args:
# chunk_ids (list[str]): List of chunk IDs to be dropped from storage
# Returns:
# True: if the cache drop successfully
# False: if the cache drop failed, or the operation is not supported
# """
@dataclass @dataclass
class BaseGraphStorage(StorageNameSpace, ABC): class BaseGraphStorage(StorageNameSpace, ABC):

View file

@@ -1526,6 +1526,7 @@ async def kg_query(
# Build context # Build context
context = await _build_query_context( context = await _build_query_context(
query,
ll_keywords_str, ll_keywords_str,
hl_keywords_str, hl_keywords_str,
knowledge_graph_inst, knowledge_graph_inst,
@@ -1744,93 +1745,52 @@ async def _get_vector_context(
query: str, query: str,
chunks_vdb: BaseVectorStorage, chunks_vdb: BaseVectorStorage,
query_param: QueryParam, query_param: QueryParam,
tokenizer: Tokenizer, ) -> list[dict]:
) -> tuple[list, list, list] | None:
""" """
Retrieve vector context from the vector database. Retrieve text chunks from the vector database without reranking or truncation.
This function performs vector search to find relevant text chunks for a query, This function performs vector search to find relevant text chunks for a query.
formats them with file path and creation time information. Reranking and truncation will be handled later in the unified processing.
Args: Args:
query: The query string to search for query: The query string to search for
chunks_vdb: Vector database containing document chunks chunks_vdb: Vector database containing document chunks
query_param: Query parameters including top_k and ids query_param: Query parameters including chunk_top_k and ids
tokenizer: Tokenizer for counting tokens
Returns: Returns:
Tuple (empty_entities, empty_relations, text_units) for combine_contexts, List of text chunks with metadata
compatible with _get_edge_data and _get_node_data format
""" """
try: try:
results = await chunks_vdb.query( # Use chunk_top_k if specified, otherwise fall back to top_k
query, top_k=query_param.top_k, ids=query_param.ids search_top_k = query_param.chunk_top_k or query_param.top_k
)
results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
if not results: if not results:
return [], [], [] return []
valid_chunks = [] valid_chunks = []
for result in results: for result in results:
if "content" in result: if "content" in result:
# Directly use content from chunks_vdb.query result chunk_with_metadata = {
chunk_with_time = {
"content": result["content"], "content": result["content"],
"created_at": result.get("created_at", None), "created_at": result.get("created_at", None),
"file_path": result.get("file_path", "unknown_source"), "file_path": result.get("file_path", "unknown_source"),
"source_type": "vector", # Mark the source type
} }
valid_chunks.append(chunk_with_time) valid_chunks.append(chunk_with_metadata)
if not valid_chunks:
return [], [], []
# Apply reranking if enabled
global_config = chunks_vdb.global_config
valid_chunks = await apply_rerank_if_enabled(
query=query,
retrieved_docs=valid_chunks,
global_config=global_config,
top_k=query_param.top_k,
)
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug( logger.debug(
f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})" f"Vector search retrieved {len(valid_chunks)} chunks (top_k: {search_top_k})"
)
logger.info(
f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
) )
return valid_chunks
if not maybe_trun_chunks:
return [], [], []
# Create empty entities and relations contexts
entities_context = []
relations_context = []
# Create text_units_context directly as a list of dictionaries
text_units_context = []
for i, chunk in enumerate(maybe_trun_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk["file_path"],
}
)
return entities_context, relations_context, text_units_context
except Exception as e: except Exception as e:
logger.error(f"Error in _get_vector_context: {e}") logger.error(f"Error in _get_vector_context: {e}")
return [], [], [] return []
async def _build_query_context( async def _build_query_context(
query: str,
ll_keywords: str, ll_keywords: str,
hl_keywords: str, hl_keywords: str,
knowledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
@@ -1838,27 +1798,36 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
query_param: QueryParam, query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode chunks_vdb: BaseVectorStorage = None,
): ):
logger.info(f"Process {os.getpid()} building query context...") logger.info(f"Process {os.getpid()} building query context...")
# Handle local and global modes as before # Collect all chunks from different sources
all_chunks = []
entities_context = []
relations_context = []
# Handle local and global modes
if query_param.mode == "local": if query_param.mode == "local":
entities_context, relations_context, text_units_context = await _get_node_data( entities_context, relations_context, entity_chunks = await _get_node_data(
ll_keywords, ll_keywords,
knowledge_graph_inst, knowledge_graph_inst,
entities_vdb, entities_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
) )
all_chunks.extend(entity_chunks)
elif query_param.mode == "global": elif query_param.mode == "global":
entities_context, relations_context, text_units_context = await _get_edge_data( entities_context, relations_context, relationship_chunks = await _get_edge_data(
hl_keywords, hl_keywords,
knowledge_graph_inst, knowledge_graph_inst,
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
) )
all_chunks.extend(relationship_chunks)
else: # hybrid or mix mode else: # hybrid or mix mode
ll_data = await _get_node_data( ll_data = await _get_node_data(
ll_keywords, ll_keywords,
@@ -1875,61 +1844,58 @@ async def _build_query_context(
query_param, query_param,
) )
( (ll_entities_context, ll_relations_context, ll_chunks) = ll_data
ll_entities_context, (hl_entities_context, hl_relations_context, hl_chunks) = hl_data
ll_relations_context,
ll_text_units_context,
) = ll_data
( # Collect chunks from entity and relationship sources
hl_entities_context, all_chunks.extend(ll_chunks)
hl_relations_context, all_chunks.extend(hl_chunks)
hl_text_units_context,
) = hl_data
# Initialize vector data with empty lists # Get vector chunks if in mix mode
vector_entities_context, vector_relations_context, vector_text_units_context = ( if query_param.mode == "mix" and chunks_vdb:
[], vector_chunks = await _get_vector_context(
[], query,
[],
)
# Only get vector data if in mix mode
if query_param.mode == "mix" and hasattr(query_param, "original_query"):
# Get tokenizer from text_chunks_db
tokenizer = text_chunks_db.global_config.get("tokenizer")
# Get vector context in triple format
vector_data = await _get_vector_context(
query_param.original_query, # We need to pass the original query
chunks_vdb, chunks_vdb,
query_param, query_param,
tokenizer,
) )
all_chunks.extend(vector_chunks)
# If vector_data is not None, unpack it # Combine entities and relations contexts
if vector_data is not None:
(
vector_entities_context,
vector_relations_context,
vector_text_units_context,
) = vector_data
# Combine and deduplicate the entities, relationships, and sources
entities_context = process_combine_contexts( entities_context = process_combine_contexts(
hl_entities_context, ll_entities_context, vector_entities_context hl_entities_context, ll_entities_context
) )
relations_context = process_combine_contexts( relations_context = process_combine_contexts(
hl_relations_context, ll_relations_context, vector_relations_context hl_relations_context, ll_relations_context
) )
text_units_context = process_combine_contexts(
hl_text_units_context, ll_text_units_context, vector_text_units_context # Process all chunks uniformly: deduplication, reranking, and token truncation
processed_chunks = await process_chunks_unified(
query=query,
chunks=all_chunks,
query_param=query_param,
global_config=text_chunks_db.global_config,
source_type="mixed",
)
# Build final text_units_context from processed chunks
text_units_context = []
for i, chunk in enumerate(processed_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
}
) )
logger.info(
f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
)
# not necessary to use LLM to generate a response # not necessary to use LLM to generate a response
if not entities_context and not relations_context: if not entities_context and not relations_context:
return None return None
# 转换为 JSON 字符串
entities_str = json.dumps(entities_context, ensure_ascii=False) entities_str = json.dumps(entities_context, ensure_ascii=False)
relations_str = json.dumps(relations_context, ensure_ascii=False) relations_str = json.dumps(relations_context, ensure_ascii=False)
text_units_str = json.dumps(text_units_context, ensure_ascii=False) text_units_str = json.dumps(text_units_context, ensure_ascii=False)
@@ -1975,15 +1941,6 @@ async def _get_node_data(
if not len(results): if not len(results):
return "", "", "" return "", "", ""
# Apply reranking if enabled for entity results
global_config = entities_vdb.global_config
results = await apply_rerank_if_enabled(
query=query,
retrieved_docs=results,
global_config=global_config,
top_k=query_param.top_k,
)
# Extract all entity IDs from your results list # Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results] node_ids = [r["entity_name"] for r in results]
@@ -2085,16 +2042,7 @@ async def _get_node_data(
} }
) )
text_units_context = [] return entities_context, relations_context, use_text_units
for i, t in enumerate(use_text_units):
text_units_context.append(
{
"id": i + 1,
"content": t["content"],
"file_path": t.get("file_path", "unknown_source"),
}
)
return entities_context, relations_context, text_units_context
async def _find_most_related_text_unit_from_entities( async def _find_most_related_text_unit_from_entities(
@@ -2183,23 +2131,21 @@ async def _find_most_related_text_unit_from_entities(
logger.warning("No valid text units found") logger.warning("No valid text units found")
return [] return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") # Sort by relation counts and order, but don't truncate
all_text_units = sorted( all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
) )
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug( logger.debug(f"Found {len(all_text_units)} entity-related chunks")
f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
)
all_text_units = [t["data"] for t in all_text_units] # Add source type marking and return chunk data
return all_text_units result_chunks = []
for t in all_text_units:
chunk_data = t["data"].copy()
chunk_data["source_type"] = "entity"
result_chunks.append(chunk_data)
return result_chunks
async def _find_most_related_edges_from_entities( async def _find_most_related_edges_from_entities(
@@ -2287,15 +2233,6 @@ async def _get_edge_data(
if not len(results): if not len(results):
return "", "", "" return "", "", ""
# Apply reranking if enabled for relationship results
global_config = relationships_vdb.global_config
results = await apply_rerank_if_enabled(
query=keywords,
retrieved_docs=results,
global_config=global_config,
top_k=query_param.top_k,
)
# Prepare edge pairs in two forms: # Prepare edge pairs in two forms:
# For the batch edge properties function, use dicts. # For the batch edge properties function, use dicts.
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
@@ -2510,21 +2447,16 @@ async def _find_related_text_unit_from_relationships(
logger.warning("No valid text chunks after filtering") logger.warning("No valid text chunks after filtering")
return [] return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") logger.debug(f"Found {len(valid_text_units)} relationship-related chunks")
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug( # Add source type marking and return chunk data
f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})" result_chunks = []
) for t in valid_text_units:
chunk_data = t["data"].copy()
chunk_data["source_type"] = "relationship"
result_chunks.append(chunk_data)
all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units] return result_chunks
return all_text_units
async def naive_query( async def naive_query(
@@ -2552,12 +2484,30 @@ async def naive_query(
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = global_config["tokenizer"]
_, _, text_units_context = await _get_vector_context( chunks = await _get_vector_context(query, chunks_vdb, query_param)
query, chunks_vdb, query_param, tokenizer
if chunks is None or len(chunks) == 0:
return PROMPTS["fail_response"]
# Process chunks using unified processing
processed_chunks = await process_chunks_unified(
query=query,
chunks=chunks,
query_param=query_param,
global_config=global_config,
source_type="vector",
) )
if text_units_context is None or len(text_units_context) == 0: # Build text_units_context from processed chunks
return PROMPTS["fail_response"] text_units_context = []
for i, chunk in enumerate(processed_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
}
)
text_units_str = json.dumps(text_units_context, ensure_ascii=False) text_units_str = json.dumps(text_units_context, ensure_ascii=False)
if query_param.only_need_context: if query_param.only_need_context:
@@ -2683,6 +2633,7 @@ async def kg_query_with_keywords(
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
context = await _build_query_context( context = await _build_query_context(
query,
ll_keywords_str, ll_keywords_str,
hl_keywords_str, hl_keywords_str,
knowledge_graph_inst, knowledge_graph_inst,
@@ -2805,8 +2756,6 @@ async def query_with_keywords(
f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}" f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
) )
param.original_query = query
# Use appropriate query method based on mode # Use appropriate query method based on mode
if param.mode in ["local", "global", "hybrid", "mix"]: if param.mode in ["local", "global", "hybrid", "mix"]:
return await kg_query_with_keywords( return await kg_query_with_keywords(
@@ -2887,3 +2836,68 @@ async def apply_rerank_if_enabled(
except Exception as e: except Exception as e:
logger.error(f"Error during reranking: {e}, using original documents") logger.error(f"Error during reranking: {e}, using original documents")
return retrieved_docs return retrieved_docs
async def process_chunks_unified(
    query: str,
    chunks: list[dict],
    query_param: QueryParam,
    global_config: dict,
    source_type: str = "mixed",
) -> list[dict]:
    """
    Run the shared post-retrieval pipeline on text chunks: deduplicate by
    content, optionally rerank against the query, then enforce the token budget.

    Args:
        query: Search query used when reranking is enabled
        chunks: Raw text chunks gathered from one or more retrieval sources
        query_param: Query parameters (rerank top-k, max token budget)
        global_config: Global configuration dict (rerank flag, tokenizer)
        source_type: Label used in log messages ("vector", "entity",
            "relationship", "mixed")

    Returns:
        The deduplicated, optionally reranked, token-truncated chunk list.
    """
    if not chunks:
        return []

    # 1. Keep the first occurrence of each content string; chunks with
    #    empty/missing content are dropped as well.
    seen: set = set()
    deduped: list[dict] = []
    for item in chunks:
        text = item.get("content", "")
        if text and text not in seen:
            seen.add(text)
            deduped.append(item)

    logger.debug(
        f"Deduplication: {len(deduped)} chunks (original: {len(chunks)})"
    )

    # 2. Rerank only when globally enabled, a query was supplied, and there
    #    is something left to rank. A falsy chunk_rerank_top_k means "keep all".
    if deduped and query and global_config.get("enable_rerank", False):
        keep = query_param.chunk_rerank_top_k or len(deduped)
        deduped = await apply_rerank_if_enabled(
            query=query,
            retrieved_docs=deduped,
            global_config=global_config,
            top_k=keep,
        )
        logger.debug(f"Rerank: {len(deduped)} chunks (source: {source_type})")

    # 3. Final token-based truncation, skipped when no tokenizer is configured.
    tokenizer = global_config.get("tokenizer")
    if deduped and tokenizer:
        before = len(deduped)
        deduped = truncate_list_by_token_size(
            deduped,
            key=lambda c: c.get("content", ""),
            max_token_size=query_param.max_token_for_text_unit,
            tokenizer=tokenizer,
        )
        logger.debug(
            f"Token truncation: {len(deduped)} chunks from {before} "
            f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})"
        )

    return deduped