From 04a57445da4d1e1c75776e392ebec89185755c30 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Tue, 8 Jul 2025 13:31:05 +0800 Subject: [PATCH] update chunks truncation method --- README-zh.md | 10 ++ README.md | 12 +- env.example | 4 +- lightrag/base.py | 27 ++-- lightrag/operate.py | 338 +++++++++++++++++++++++--------------------- 5 files changed, 211 insertions(+), 180 deletions(-) diff --git a/README-zh.md b/README-zh.md index 45335489..7dd7e975 100644 --- a/README-zh.md +++ b/README-zh.md @@ -294,6 +294,16 @@ class QueryParam: top_k: int = int(os.getenv("TOP_K", "60")) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5")) + """Number of text chunks to retrieve initially from vector search. + If None, defaults to top_k value. + """ + + chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5")) + """Number of text chunks to keep after reranking. + If None, keeps all chunks returned from initial retrieval. + """ + max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) """Maximum number of tokens allowed for each retrieved text chunk.""" diff --git a/README.md b/README.md index e812e8df..79479da9 100644 --- a/README.md +++ b/README.md @@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d python examples/lightrag_openai_demo.py ``` -For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code’s LLM and embedding configurations accordingly. +For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly. **Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory. @@ -300,6 +300,16 @@ class QueryParam: top_k: int = int(os.getenv("TOP_K", "60")) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5")) + """Number of text chunks to retrieve initially from vector search. + If None, defaults to top_k value. + """ + + chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5")) + """Number of text chunks to keep after reranking. + If None, keeps all chunks returned from initial retrieval. 
+ """ + max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) """Maximum number of tokens allowed for each retrieved text chunk.""" diff --git a/env.example b/env.example index c4a09cad..e09494b8 100644 --- a/env.example +++ b/env.example @@ -46,7 +46,9 @@ OLLAMA_EMULATING_MODEL_TAG=latest # HISTORY_TURNS=3 # COSINE_THRESHOLD=0.2 # TOP_K=60 -# MAX_TOKEN_TEXT_CHUNK=4000 +# CHUNK_TOP_K=5 +# CHUNK_RERANK_TOP_K=5 +# MAX_TOKEN_TEXT_CHUNK=6000 # MAX_TOKEN_RELATION_DESC=4000 # MAX_TOKEN_ENTITY_DESC=4000 diff --git a/lightrag/base.py b/lightrag/base.py index 57cb2ac6..97564ac2 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -60,7 +60,17 @@ class QueryParam: top_k: int = int(os.getenv("TOP_K", "60")) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" - max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5")) + """Number of text chunks to retrieve initially from vector search. + If None, defaults to top_k value. + """ + + chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5")) + """Number of text chunks to keep after reranking. + If None, keeps all chunks returned from initial retrieval. + """ + + max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000")) """Maximum number of tokens allowed for each retrieved text chunk.""" max_token_for_global_context: int = int( @@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC): False: if the cache drop failed, or the cache mode is not supported """ - # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool: - # """Delete specific cache records from storage by chunk IDs - - # Importance notes for in-memory storage: - # 1. Changes will be persisted to disk during the next index_done_callback - # 2. update flags to notify other processes that data persistence is needed - - # Args: - # chunk_ids (list[str]): List of chunk IDs to be dropped from storage - - # Returns: - # True: if the cache drop successfully - # False: if the cache drop failed, or the operation is not supported - # """ - @dataclass class BaseGraphStorage(StorageNameSpace, ABC): diff --git a/lightrag/operate.py b/lightrag/operate.py index 645c1e85..f9f53285 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1526,6 +1526,7 @@ async def kg_query( # Build context context = await _build_query_context( + query, ll_keywords_str, hl_keywords_str, knowledge_graph_inst, @@ -1744,93 +1745,52 @@ async def _get_vector_context( query: str, chunks_vdb: BaseVectorStorage, query_param: QueryParam, - tokenizer: Tokenizer, -) -> tuple[list, list, list] | None: +) -> list[dict]: """ - Retrieve vector context from the vector database. + Retrieve text chunks from the vector database without reranking or truncation. - This function performs vector search to find relevant text chunks for a query, - formats them with file path and creation time information. + This function performs vector search to find relevant text chunks for a query. + Reranking and truncation will be handled later in the unified processing. 
Args: query: The query string to search for chunks_vdb: Vector database containing document chunks - query_param: Query parameters including top_k and ids - tokenizer: Tokenizer for counting tokens + query_param: Query parameters including chunk_top_k and ids Returns: - Tuple (empty_entities, empty_relations, text_units) for combine_contexts, - compatible with _get_edge_data and _get_node_data format + List of text chunks with metadata """ try: - results = await chunks_vdb.query( - query, top_k=query_param.top_k, ids=query_param.ids - ) + # Use chunk_top_k if specified, otherwise fall back to top_k + search_top_k = query_param.chunk_top_k or query_param.top_k + + results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids) if not results: - return [], [], [] + return [] valid_chunks = [] for result in results: if "content" in result: - # Directly use content from chunks_vdb.query result - chunk_with_time = { + chunk_with_metadata = { "content": result["content"], "created_at": result.get("created_at", None), "file_path": result.get("file_path", "unknown_source"), + "source_type": "vector", # Mark the source type } - valid_chunks.append(chunk_with_time) - - if not valid_chunks: - return [], [], [] - - # Apply reranking if enabled - global_config = chunks_vdb.global_config - valid_chunks = await apply_rerank_if_enabled( - query=query, - retrieved_docs=valid_chunks, - global_config=global_config, - top_k=query_param.top_k, - ) - - maybe_trun_chunks = truncate_list_by_token_size( - valid_chunks, - key=lambda x: x["content"], - max_token_size=query_param.max_token_for_text_unit, - tokenizer=tokenizer, - ) + valid_chunks.append(chunk_with_metadata) logger.debug( - f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})" - ) - logger.info( - f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}" + f"Vector search retrieved {len(valid_chunks)} chunks (top_k: {search_top_k})" ) + return valid_chunks - if not maybe_trun_chunks: - return [], [], [] - - # Create empty entities and relations contexts - entities_context = [] - relations_context = [] - - # Create text_units_context directly as a list of dictionaries - text_units_context = [] - for i, chunk in enumerate(maybe_trun_chunks): - text_units_context.append( - { - "id": i + 1, - "content": chunk["content"], - "file_path": chunk["file_path"], - } - ) - - return entities_context, relations_context, text_units_context except Exception as e: logger.error(f"Error in _get_vector_context: {e}") - return [], [], [] + return [] async def _build_query_context( + query: str, ll_keywords: str, hl_keywords: str, knowledge_graph_inst: BaseGraphStorage, @@ -1838,27 +1798,36 @@ async def _build_query_context( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode + chunks_vdb: BaseVectorStorage = None, ): logger.info(f"Process {os.getpid()} building query context...") - # Handle local and global modes as before + # Collect all chunks from different sources + all_chunks = [] + entities_context = [] + relations_context = [] + + # Handle local and global modes if query_param.mode == "local": - entities_context, relations_context, text_units_context = await _get_node_data( + entities_context, relations_context, entity_chunks = await _get_node_data( ll_keywords, knowledge_graph_inst, entities_vdb, text_chunks_db, query_param, ) + 
all_chunks.extend(entity_chunks) + elif query_param.mode == "global": - entities_context, relations_context, text_units_context = await _get_edge_data( + entities_context, relations_context, relationship_chunks = await _get_edge_data( hl_keywords, knowledge_graph_inst, relationships_vdb, text_chunks_db, query_param, ) + all_chunks.extend(relationship_chunks) + else: # hybrid or mix mode ll_data = await _get_node_data( ll_keywords, @@ -1875,61 +1844,58 @@ async def _build_query_context( query_param, ) - ( - ll_entities_context, - ll_relations_context, - ll_text_units_context, - ) = ll_data + (ll_entities_context, ll_relations_context, ll_chunks) = ll_data + (hl_entities_context, hl_relations_context, hl_chunks) = hl_data - ( - hl_entities_context, - hl_relations_context, - hl_text_units_context, - ) = hl_data + # Collect chunks from entity and relationship sources + all_chunks.extend(ll_chunks) + all_chunks.extend(hl_chunks) - # Initialize vector data with empty lists - vector_entities_context, vector_relations_context, vector_text_units_context = ( - [], - [], - [], - ) - - # Only get vector data if in mix mode - if query_param.mode == "mix" and hasattr(query_param, "original_query"): - # Get tokenizer from text_chunks_db - tokenizer = text_chunks_db.global_config.get("tokenizer") - - # Get vector context in triple format - vector_data = await _get_vector_context( - query_param.original_query, # We need to pass the original query + # Get vector chunks if in mix mode + if query_param.mode == "mix" and chunks_vdb: + vector_chunks = await _get_vector_context( + query, chunks_vdb, query_param, - tokenizer, ) + all_chunks.extend(vector_chunks) - # If vector_data is not None, unpack it - if vector_data is not None: - ( - vector_entities_context, - vector_relations_context, - vector_text_units_context, - ) = vector_data - - # Combine and deduplicate the entities, relationships, and sources + # Combine entities and relations contexts entities_context = process_combine_contexts( - hl_entities_context, ll_entities_context, vector_entities_context + hl_entities_context, ll_entities_context ) relations_context = process_combine_contexts( - hl_relations_context, ll_relations_context, vector_relations_context + hl_relations_context, ll_relations_context ) - text_units_context = process_combine_contexts( - hl_text_units_context, ll_text_units_context, vector_text_units_context + + # Process all chunks uniformly: deduplication, reranking, and token truncation + processed_chunks = await process_chunks_unified( + query=query, + chunks=all_chunks, + query_param=query_param, + global_config=text_chunks_db.global_config, + source_type="mixed", + ) + + # Build final text_units_context from processed chunks + text_units_context = [] + for i, chunk in enumerate(processed_chunks): + text_units_context.append( + { + "id": i + 1, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } ) + + logger.info( + f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks" + ) + # not necessary to use LLM to generate a response if not entities_context and not relations_context: return None - # 转换为 JSON 字符串 entities_str = json.dumps(entities_context, ensure_ascii=False) relations_str = json.dumps(relations_context, ensure_ascii=False) text_units_str = json.dumps(text_units_context, ensure_ascii=False) @@ -1975,15 +1941,6 @@ async def _get_node_data( if not len(results): return "", "", "" - # Apply reranking if enabled for entity 
results - global_config = entities_vdb.global_config - results = await apply_rerank_if_enabled( - query=query, - retrieved_docs=results, - global_config=global_config, - top_k=query_param.top_k, - ) - # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] @@ -2085,16 +2042,7 @@ async def _get_node_data( } ) - text_units_context = [] - for i, t in enumerate(use_text_units): - text_units_context.append( - { - "id": i + 1, - "content": t["content"], - "file_path": t.get("file_path", "unknown_source"), - } - ) - return entities_context, relations_context, text_units_context + return entities_context, relations_context, use_text_units async def _find_most_related_text_unit_from_entities( @@ -2183,23 +2131,21 @@ async def _find_most_related_text_unit_from_entities( logger.warning("No valid text units found") return [] - tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") + # Sort by relation counts and order, but don't truncate all_text_units = sorted( all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) ) - all_text_units = truncate_list_by_token_size( - all_text_units, - key=lambda x: x["data"]["content"], - max_token_size=query_param.max_token_for_text_unit, - tokenizer=tokenizer, - ) - logger.debug( - f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})" - ) + logger.debug(f"Found {len(all_text_units)} entity-related chunks") - all_text_units = [t["data"] for t in all_text_units] - return all_text_units + # Add source type marking and return chunk data + result_chunks = [] + for t in all_text_units: + chunk_data = t["data"].copy() + chunk_data["source_type"] = "entity" + result_chunks.append(chunk_data) + + return result_chunks async def _find_most_related_edges_from_entities( @@ -2287,15 +2233,6 @@ async def _get_edge_data( if not len(results): return "", "", "" - # Apply reranking if enabled for relationship results - global_config = relationships_vdb.global_config - results = await apply_rerank_if_enabled( - query=keywords, - retrieved_docs=results, - global_config=global_config, - top_k=query_param.top_k, - ) - # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. 
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] @@ -2510,21 +2447,16 @@ async def _find_related_text_unit_from_relationships( logger.warning("No valid text chunks after filtering") return [] - tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") - truncated_text_units = truncate_list_by_token_size( - valid_text_units, - key=lambda x: x["data"]["content"], - max_token_size=query_param.max_token_for_text_unit, - tokenizer=tokenizer, - ) + logger.debug(f"Found {len(valid_text_units)} relationship-related chunks") - logger.debug( - f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})" - ) + # Add source type marking and return chunk data + result_chunks = [] + for t in valid_text_units: + chunk_data = t["data"].copy() + chunk_data["source_type"] = "relationship" + result_chunks.append(chunk_data) - all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units] - - return all_text_units + return result_chunks async def naive_query( @@ -2552,12 +2484,30 @@ async def naive_query( tokenizer: Tokenizer = global_config["tokenizer"] - _, _, text_units_context = await _get_vector_context( - query, chunks_vdb, query_param, tokenizer + chunks = await _get_vector_context(query, chunks_vdb, query_param) + + if chunks is None or len(chunks) == 0: + return PROMPTS["fail_response"] + + # Process chunks using unified processing + processed_chunks = await process_chunks_unified( + query=query, + chunks=chunks, + query_param=query_param, + global_config=global_config, + source_type="vector", ) - if text_units_context is None or len(text_units_context) == 0: - return PROMPTS["fail_response"] + # Build text_units_context from processed chunks + text_units_context = [] + for i, chunk in enumerate(processed_chunks): + text_units_context.append( + { + "id": i + 1, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } + ) text_units_str = json.dumps(text_units_context, ensure_ascii=False) if query_param.only_need_context: @@ -2683,6 +2633,7 @@ async def kg_query_with_keywords( hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" context = await _build_query_context( + query, ll_keywords_str, hl_keywords_str, knowledge_graph_inst, @@ -2805,8 +2756,6 @@ async def query_with_keywords( f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}" ) - param.original_query = query - # Use appropriate query method based on mode if param.mode in ["local", "global", "hybrid", "mix"]: return await kg_query_with_keywords( @@ -2887,3 +2836,68 @@ async def apply_rerank_if_enabled( except Exception as e: logger.error(f"Error during reranking: {e}, using original documents") return retrieved_docs + + +async def process_chunks_unified( + query: str, + chunks: list[dict], + query_param: QueryParam, + global_config: dict, + source_type: str = "mixed", +) -> list[dict]: + """ + Unified processing for text chunks: deduplication, reranking, and token truncation. + + Args: + query: Search query for reranking + chunks: List of text chunks to process + query_param: Query parameters containing configuration + global_config: Global configuration dictionary + source_type: Source type for logging ("vector", "entity", "relationship", "mixed") + + Returns: + Processed and filtered list of text chunks + """ + if not chunks: + return [] + + # 1. 
Deduplication based on content + seen_content = set() + unique_chunks = [] + for chunk in chunks: + content = chunk.get("content", "") + if content and content not in seen_content: + seen_content.add(content) + unique_chunks.append(chunk) + + logger.debug( + f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})" + ) + + # 2. Apply reranking if enabled and query is provided + if global_config.get("enable_rerank", False) and query and unique_chunks: + rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks) + unique_chunks = await apply_rerank_if_enabled( + query=query, + retrieved_docs=unique_chunks, + global_config=global_config, + top_k=rerank_top_k, + ) + logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})") + + # 3. Token-based final truncation + tokenizer = global_config.get("tokenizer") + if tokenizer and unique_chunks: + original_count = len(unique_chunks) + unique_chunks = truncate_list_by_token_size( + unique_chunks, + key=lambda x: x.get("content", ""), + max_token_size=query_param.max_token_for_text_unit, + tokenizer=tokenizer, + ) + logger.debug( + f"Token truncation: {len(unique_chunks)} chunks from {original_count} " + f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})" + ) + + return unique_chunks
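
The retrieval knobs this patch introduces can be tuned per query. Below is a minimal usage sketch, assuming an already-initialized LightRAG instance; the `rag` variable, working directory, and query text are illustrative and not part of the patch:

```python
from lightrag import LightRAG, QueryParam

# Hypothetical pre-initialized instance; setup is omitted here.
# rag = LightRAG(working_dir="./dickens", ...)

param = QueryParam(
    mode="mix",                    # graph retrieval plus raw vector chunks
    chunk_top_k=10,                # chunks fetched by the vector search stage
    chunk_rerank_top_k=5,          # chunks kept after reranking (if enabled)
    max_token_for_text_unit=6000,  # final token budget across all kept chunks
)
# response = await rag.aquery("What are the top themes in this story?", param=param)
```

If `enable_rerank` is off in the global config, `chunk_rerank_top_k` has no effect; the token budget applied in `process_chunks_unified` is then the only bound on the final chunk context.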
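
Both new fields fall back with a plain `or`, which is why their docstrings mention `None` despite the `int` annotation: any falsy value (`None` or `0`) defers to the next setting. A self-contained illustration of the fallback used in `_get_vector_context`:

```python
def effective_search_top_k(chunk_top_k: int | None, top_k: int) -> int:
    # Mirrors `search_top_k = query_param.chunk_top_k or query_param.top_k`.
    return chunk_top_k or top_k

assert effective_search_top_k(10, 60) == 10
assert effective_search_top_k(None, 60) == 60
assert effective_search_top_k(0, 60) == 60  # 0 is falsy, so it falls back too
```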
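
`process_chunks_unified` now merges chunks arriving from the entity, relationship, and vector paths before reranking, so a text unit reached by several paths survives once, carrying the `source_type` tag of its first occurrence. A toy run of the deduplication step (the sample contents are invented):

```python
chunks = [
    {"content": "Marley was dead, to begin with.", "source_type": "entity"},
    {"content": "Scrooge was a tight-fisted hand.", "source_type": "vector"},
    {"content": "Marley was dead, to begin with.", "source_type": "relationship"},
]

seen_content, unique_chunks = set(), []
for chunk in chunks:
    content = chunk.get("content", "")
    if content and content not in seen_content:
        seen_content.add(content)
        unique_chunks.append(chunk)

assert [c["source_type"] for c in unique_chunks] == ["entity", "vector"]
```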
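
The same defaults can be set through the environment variables added to env.example. Note that the `QueryParam` defaults are read with `os.getenv(...)` in the dataclass field initializers, which run when the class body is executed, so the variables must be in the environment before `lightrag` is imported. A sketch:

```python
import os

# Must run before `import lightrag`: field defaults such as
# int(os.getenv("CHUNK_TOP_K", "5")) are evaluated at import time.
os.environ["CHUNK_TOP_K"] = "10"
os.environ["CHUNK_RERANK_TOP_K"] = "5"
os.environ["MAX_TOKEN_TEXT_CHUNK"] = "6000"

from lightrag import LightRAG, QueryParam  # noqa: E402
```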