diff --git a/env.example b/env.example index d690fb00..57d03f8b 100644 --- a/env.example +++ b/env.example @@ -48,7 +48,7 @@ OLLAMA_EMULATING_MODEL_TAG=latest ######################## ### Query Configuration ######################## -# LLM responde cache for query (Not valid for streaming response +# LLM responde cache for query (Not valid for streaming response) ENABLE_LLM_CACHE=true # HISTORY_TURNS=0 # COSINE_THRESHOLD=0.2 @@ -62,8 +62,8 @@ ENABLE_LLM_CACHE=true # MAX_RELATION_TOKENS=10000 ### control the maximum tokens send to LLM (include entities, raltions and chunks) # MAX_TOTAL_TOKENS=32000 -### maxumium related chunks grab from single entity or relations -# RELATED_CHUNK_NUMBER=10 +### maximum number of related chunks per source entity or relation (higher values increase re-ranking time) +# RELATED_CHUNK_NUMBER=5 ### Reranker configuration (Set ENABLE_RERANK to true in reranking model is configed) ENABLE_RERANK=False diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py index 744d1e58..a5c46e29 100644 --- a/lightrag/api/__init__.py +++ b/lightrag/api/__init__.py @@ -1 +1 @@ -__api_version__ = "0188" +__api_version__ = "0189" diff --git a/lightrag/constants.py b/lightrag/constants.py index 8ce400be..26205689 100644 --- a/lightrag/constants.py +++ b/lightrag/constants.py @@ -21,7 +21,7 @@ DEFAULT_MAX_TOTAL_TOKENS = 32000 DEFAULT_HISTORY_TURNS = 0 DEFAULT_ENABLE_RERANK = True DEFAULT_COSINE_THRESHOLD = 0.2 -DEFAULT_RELATED_CHUNK_NUMBER = 10 +DEFAULT_RELATED_CHUNK_NUMBER = 5 # Separator for graph fields GRAPH_FIELD_SEP = "" diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 8a13fd21..ef73a206 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -536,7 +536,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): # Load the collection if it's not already loaded # In Milvus, collections need to be loaded before they can be searched self._client.load_collection(self.namespace) - logger.debug(f"Collection {self.namespace} loaded successfully") + # logger.debug(f"Collection {self.namespace} loaded successfully") except Exception as e: logger.error(f"Failed to load collection {self.namespace}: {e}") diff --git a/lightrag/operate.py b/lightrag/operate.py index e72ae3d7..a3075210 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -18,7 +18,6 @@ from .utils import ( pack_user_ass_to_openai_messages, split_string_by_multi_markers, truncate_list_by_token_size, - process_combine_contexts, compute_args_hash, handle_cache, save_to_cache, @@ -27,6 +26,8 @@ from .utils import ( use_llm_func_with_cache, update_chunk_cache_list, remove_think_tags, + linear_gradient_weighted_polling, + process_chunks_unified, ) from .base import ( BaseGraphStorage, @@ -1916,6 +1917,7 @@ async def _get_vector_context( "created_at": result.get("created_at", None), "file_path": result.get("file_path", "unknown_source"), "source_type": "vector", # Mark the source type + "chunk_id": result.get("id"), # Add chunk_id for deduplication } valid_chunks.append(chunk_with_metadata) @@ -1942,67 +1944,50 @@ async def _build_query_context( ): logger.info(f"Process {os.getpid()} building query context...") - # Collect all chunks from different sources - all_chunks = [] + # Collect chunks from different sources separately + vector_chunks = [] + entity_chunks = [] + relation_chunks = [] entities_context = [] relations_context = [] # Store original data for later text chunk retrieval - original_node_datas = [] - original_edge_datas = [] + local_entities = [] + local_relations 
= [] + global_entities = [] + global_relations = [] # Handle local and global modes if query_param.mode == "local": - ( - entities_context, - relations_context, - node_datas, - use_relations, - ) = await _get_node_data( + local_entities, local_relations = await _get_node_data( ll_keywords, knowledge_graph_inst, entities_vdb, query_param, ) - original_node_datas = node_datas - original_edge_datas = use_relations elif query_param.mode == "global": - ( - entities_context, - relations_context, - edge_datas, - use_entities, - ) = await _get_edge_data( + global_relations, global_entities = await _get_edge_data( hl_keywords, knowledge_graph_inst, relationships_vdb, query_param, ) - original_edge_datas = edge_datas - original_node_datas = use_entities else: # hybrid or mix mode - ll_data = await _get_node_data( + local_entities, local_relations = await _get_node_data( ll_keywords, knowledge_graph_inst, entities_vdb, query_param, ) - hl_data = await _get_edge_data( + global_relations, global_entities = await _get_edge_data( hl_keywords, knowledge_graph_inst, relationships_vdb, query_param, ) - (ll_entities_context, ll_relations_context, ll_node_datas, ll_edge_datas) = ( - ll_data - ) - (hl_entities_context, hl_relations_context, hl_edge_datas, hl_node_datas) = ( - hl_data - ) - # Get vector chunks first if in mix mode if query_param.mode == "mix" and chunks_vdb: vector_chunks = await _get_vector_context( @@ -2010,22 +1995,117 @@ async def _build_query_context( chunks_vdb, query_param, ) - all_chunks.extend(vector_chunks) - # Store original data from both sources - original_node_datas = ll_node_datas + hl_node_datas - original_edge_datas = ll_edge_datas + hl_edge_datas + # Use round-robin merge to combine local and global data fairly + final_entities = [] + seen_entities = set() - # Combine entities and relations contexts - entities_context = process_combine_contexts( - ll_entities_context, hl_entities_context - ) - relations_context = process_combine_contexts( - hl_relations_context, ll_relations_context + # Round-robin merge entities + max_len = max(len(local_entities), len(global_entities)) + for i in range(max_len): + # First from local + if i < len(local_entities): + entity = local_entities[i] + entity_name = entity.get("entity_name") + if entity_name and entity_name not in seen_entities: + final_entities.append(entity) + seen_entities.add(entity_name) + + # Then from global + if i < len(global_entities): + entity = global_entities[i] + entity_name = entity.get("entity_name") + if entity_name and entity_name not in seen_entities: + final_entities.append(entity) + seen_entities.add(entity_name) + + # Round-robin merge relations + final_relations = [] + seen_relations = set() + + max_len = max(len(local_relations), len(global_relations)) + for i in range(max_len): + # First from local + if i < len(local_relations): + relation = local_relations[i] + # Build relation unique identifier + if "src_tgt" in relation: + rel_key = tuple(sorted(relation["src_tgt"])) + else: + rel_key = tuple( + sorted([relation.get("src_id"), relation.get("tgt_id")]) + ) + + if rel_key not in seen_relations: + final_relations.append(relation) + seen_relations.add(rel_key) + + # Then from global + if i < len(global_relations): + relation = global_relations[i] + # Build relation unique identifier + if "src_tgt" in relation: + rel_key = tuple(sorted(relation["src_tgt"])) + else: + rel_key = tuple( + sorted([relation.get("src_id"), relation.get("tgt_id")]) + ) + + if rel_key not in seen_relations: + 
final_relations.append(relation) + seen_relations.add(rel_key) + + # Generate entities context + entities_context = [] + for i, n in enumerate(final_entities): + created_at = n.get("created_at", "UNKNOWN") + if isinstance(created_at, (int, float)): + created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) + + # Get file path from node data + file_path = n.get("file_path", "unknown_source") + + entities_context.append( + { + "id": i + 1, + "entity": n["entity_name"], + "type": n.get("entity_type", "UNKNOWN"), + "description": n.get("description", "UNKNOWN"), + "created_at": created_at, + "file_path": file_path, + } ) - logger.info( - f"Initial context: {len(entities_context)} entities, {len(relations_context)} relations, {len(all_chunks)} chunks" + # Generate relations context + relations_context = [] + for i, e in enumerate(final_relations): + created_at = e.get("created_at", "UNKNOWN") + # Convert timestamp to readable format + if isinstance(created_at, (int, float)): + created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) + + # Get file path from edge data + file_path = e.get("file_path", "unknown_source") + + # Handle different relation data formats + if "src_tgt" in e: + entity1, entity2 = e["src_tgt"] + else: + entity1, entity2 = e.get("src_id"), e.get("tgt_id") + + relations_context.append( + { + "id": i + 1, + "entity1": entity1, + "entity2": entity2, + "description": e.get("description", "UNKNOWN"), + "created_at": created_at, + "file_path": file_path, + } + ) + + logger.debug( + f"Initial KG query results: {len(entities_context)} entities, {len(relations_context)} relations" ) # Unified token control system - Apply precise token limits to entities and relations @@ -2053,8 +2133,6 @@ async def _build_query_context( # Truncate entities based on complete JSON serialization if entities_context: - original_entity_count = len(entities_context) - # Process entities context to replace GRAPH_FIELD_SEP with : in file_path fields for entity in entities_context: if "file_path" in entity and entity["file_path"]: @@ -2066,15 +2144,9 @@ async def _build_query_context( max_token_size=max_entity_tokens, tokenizer=tokenizer, ) - if len(entities_context) < original_entity_count: - logger.debug( - f"Truncated entities: {original_entity_count} -> {len(entities_context)} (entity max tokens: {max_entity_tokens})" - ) # Truncate relations based on complete JSON serialization if relations_context: - original_relation_count = len(relations_context) - # Process relations context to replace GRAPH_FIELD_SEP with : in file_path fields for relation in relations_context: if "file_path" in relation and relation["file_path"]: @@ -2088,30 +2160,28 @@ async def _build_query_context( max_token_size=max_relation_tokens, tokenizer=tokenizer, ) - if len(relations_context) < original_relation_count: - logger.debug( - f"Truncated relations: {original_relation_count} -> {len(relations_context)} (relation max tokens: {max_relation_tokens})" - ) # After truncation, get text chunks based on final entities and relations - logger.info("Getting text chunks based on truncated entities and relations...") + logger.info( + f"Truncated KG query results: {len(entities_context)} entities, {len(relations_context)} relations" + ) # Create filtered data based on truncated context final_node_datas = [] - if entities_context and original_node_datas: + if entities_context and final_entities: final_entity_names = {e["entity"] for e in entities_context} seen_nodes = set() - for node in 
original_node_datas: + for node in final_entities: name = node.get("entity_name") if name in final_entity_names and name not in seen_nodes: final_node_datas.append(node) seen_nodes.add(name) final_edge_datas = [] - if relations_context and original_edge_datas: + if relations_context and final_relations: final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context} seen_edges = set() - for edge in original_edge_datas: + for edge in final_relations: src, tgt = edge.get("src_id"), edge.get("tgt_id") if src is None or tgt is None: src, tgt = edge.get("src_tgt", (None, None)) @@ -2122,37 +2192,75 @@ async def _build_query_context( seen_edges.add(pair) # Get text chunks based on final filtered data - text_chunk_tasks = [] - if final_node_datas: - text_chunk_tasks.append( - _find_most_related_text_unit_from_entities( - final_node_datas, - query_param, - text_chunks_db, - knowledge_graph_inst, - ) + entity_chunks = await _find_most_related_text_unit_from_entities( + final_node_datas, + query_param, + text_chunks_db, + knowledge_graph_inst, ) if final_edge_datas: - text_chunk_tasks.append( - _find_related_text_unit_from_relationships( - final_edge_datas, - query_param, - text_chunks_db, - ) + relation_chunks = await _find_related_text_unit_from_relationships( + final_edge_datas, + query_param, + text_chunks_db, + entity_chunks, ) - # Execute text chunk retrieval in parallel - if text_chunk_tasks: - text_chunk_results = await asyncio.gather(*text_chunk_tasks) - for chunks in text_chunk_results: - if chunks: - all_chunks.extend(chunks) + # Round-robin merge chunks from different sources with deduplication by chunk_id + merged_chunks = [] + seen_chunk_ids = set() + max_len = max(len(vector_chunks), len(entity_chunks), len(relation_chunks)) + origin_len = len(vector_chunks) + len(entity_chunks) + len(relation_chunks) - # Apply token processing to chunks + for i in range(max_len): + # Add from vector chunks first (Naive mode) + if i < len(vector_chunks): + chunk = vector_chunks[i] + chunk_id = chunk.get("chunk_id") or chunk.get("id") + if chunk_id and chunk_id not in seen_chunk_ids: + seen_chunk_ids.add(chunk_id) + merged_chunks.append( + { + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } + ) + + # Add from entity chunks (Local mode) + if i < len(entity_chunks): + chunk = entity_chunks[i] + chunk_id = chunk.get("chunk_id") or chunk.get("id") + if chunk_id and chunk_id not in seen_chunk_ids: + seen_chunk_ids.add(chunk_id) + merged_chunks.append( + { + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } + ) + + # Add from relation chunks (Global mode) + if i < len(relation_chunks): + chunk = relation_chunks[i] + chunk_id = chunk.get("chunk_id") or chunk.get("id") + if chunk_id and chunk_id not in seen_chunk_ids: + seen_chunk_ids.add(chunk_id) + merged_chunks.append( + { + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } + ) + + logger.debug( + f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}" + ) + + # Apply token processing to merged chunks text_units_context = [] - if all_chunks: + if merged_chunks: # Calculate dynamic token limit for text chunks entities_str = json.dumps(entities_context, ensure_ascii=False) relations_str = json.dumps(relations_context, ensure_ascii=False) @@ -2229,37 +2337,29 @@ async def _build_query_context( f"Token allocation - Total: {max_total_tokens}, History: {history_tokens}, SysPrompt: {sys_prompt_overhead}, 
KG: {kg_context_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}" ) - # Re-process chunks with dynamic token limit - if all_chunks: - # Create a temporary query_param copy with adjusted chunk token limit - temp_chunks = [ - {"content": chunk["content"], "file_path": chunk["file_path"]} - for chunk in all_chunks - ] + # Apply token truncation to chunks using the dynamic limit + truncated_chunks = await process_chunks_unified( + query=query, + unique_chunks=merged_chunks, + query_param=query_param, + global_config=text_chunks_db.global_config, + source_type=query_param.mode, + chunk_token_limit=available_chunk_tokens, # Pass dynamic limit + ) - # Apply token truncation to chunks using the dynamic limit - truncated_chunks = await process_chunks_unified( - query=query, - chunks=temp_chunks, - query_param=query_param, - global_config=text_chunks_db.global_config, - source_type="mixed", - chunk_token_limit=available_chunk_tokens, # Pass dynamic limit + # Rebuild text_units_context with truncated chunks + for i, chunk in enumerate(truncated_chunks): + text_units_context.append( + { + "id": i + 1, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } ) - # Rebuild text_units_context with truncated chunks - for i, chunk in enumerate(truncated_chunks): - text_units_context.append( - { - "id": i + 1, - "content": chunk["content"], - "file_path": chunk.get("file_path", "unknown_source"), - } - ) - - logger.debug( - f"Re-truncated chunks for dynamic token limit: {len(temp_chunks)} -> {len(text_units_context)} (chunk available tokens: {available_chunk_tokens})" - ) + logger.debug( + f"Final chunk processing: {len(merged_chunks)} -> {len(text_units_context)} (chunk available tokens: {available_chunk_tokens})" + ) logger.info( f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks" @@ -2311,7 +2411,7 @@ async def _get_node_data( ) if not len(results): - return "", "", [], [] + return [], [] # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] @@ -2350,49 +2450,9 @@ async def _get_node_data( f"Local query: {len(node_datas)} entites, {len(use_relations)} relations" ) - # build prompt - entities_context = [] - for i, n in enumerate(node_datas): - created_at = n.get("created_at", "UNKNOWN") - if isinstance(created_at, (int, float)): - created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) - - # Get file path from node data - file_path = n.get("file_path", "unknown_source") - - entities_context.append( - { - "id": i + 1, - "entity": n["entity_name"], - "type": n.get("entity_type", "UNKNOWN"), - "description": n.get("description", "UNKNOWN"), - "created_at": created_at, - "file_path": file_path, - } - ) - - relations_context = [] - for i, e in enumerate(use_relations): - created_at = e.get("created_at", "UNKNOWN") - # Convert timestamp to readable format - if isinstance(created_at, (int, float)): - created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) - - # Get file path from edge data - file_path = e.get("file_path", "unknown_source") - - relations_context.append( - { - "id": i + 1, - "entity1": e["src_tgt"][0], - "entity2": e["src_tgt"][1], - "description": e["description"], - "created_at": created_at, - "file_path": file_path, - } - ) - - return entities_context, relations_context, node_datas, use_relations + # Entities are sorted by cosine similarity + # Relations are sorted by rank + weight 
+ return node_datas, use_relations async def _find_most_related_text_unit_from_entities( @@ -2401,105 +2461,94 @@ async def _find_most_related_text_unit_from_entities( text_chunks_db: BaseKVStorage, knowledge_graph_inst: BaseGraphStorage, ): + """ + Find text chunks related to entities using linear gradient weighted polling algorithm. + + This function implements the optimized text chunk selection strategy: + 1. Sort text chunks for each entity by occurrence count in other entities + 2. Use linear gradient weighted polling to select chunks fairly + """ logger.debug(f"Searching text chunks for {len(node_datas)} entities") - text_units = [ - split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])[ - : text_chunks_db.global_config.get( - "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER - ) - ] - for dp in node_datas - if dp["source_id"] is not None - ] - - node_names = [dp["entity_name"] for dp in node_datas] - batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) - # Build the edges list in the same order as node_datas. - edges = [batch_edges_dict.get(name, []) for name in node_names] - - all_one_hop_nodes = set() - for this_edges in edges: - if not this_edges: - continue - all_one_hop_nodes.update([e[1] for e in this_edges]) - - all_one_hop_nodes = list(all_one_hop_nodes) - - # Batch retrieve one-hop node data using get_nodes_batch - all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch( - all_one_hop_nodes - ) - all_one_hop_nodes_data = [ - all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes - ] - - # Add null check for node data - all_one_hop_text_units_lookup = { - k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP])) - for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data) - if v is not None and "source_id" in v # Add source_id check - } - - all_text_units_lookup = {} - tasks = [] - - for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)): - for c_id in this_text_units: - if c_id not in all_text_units_lookup: - all_text_units_lookup[c_id] = index - tasks.append((c_id, index, this_edges)) - - # Process in batches tasks at a time to avoid overwhelming resources - batch_size = 5 - results = [] - - for i in range(0, len(tasks), batch_size): - batch_tasks = tasks[i : i + batch_size] - batch_results = await asyncio.gather( - *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in batch_tasks] - ) - results.extend(batch_results) - - for (c_id, index, this_edges), data in zip(tasks, results): - all_text_units_lookup[c_id] = { - "data": data, - "order": index, - "relation_counts": 0, - } - - if this_edges: - for e in this_edges: - if ( - e[1] in all_one_hop_text_units_lookup - and c_id in all_one_hop_text_units_lookup[e[1]] - ): - all_text_units_lookup[c_id]["relation_counts"] += 1 - - # Filter out None values and ensure data has content - all_text_units = [ - {"id": k, **v} - for k, v in all_text_units_lookup.items() - if v is not None and v.get("data") is not None and "content" in v["data"] - ] - - if not all_text_units: - logger.warning("No valid text units found") + if not node_datas: return [] - # Sort by relation counts and order, but don't truncate - all_text_units = sorted( - all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) + # Step 1: Collect all text chunks for each entity + entities_with_chunks = [] + for entity in node_datas: + if entity.get("source_id"): + chunks = split_string_by_multi_markers( + entity["source_id"], [GRAPH_FIELD_SEP] + ) + if chunks: + 
entities_with_chunks.append( + { + "entity_name": entity["entity_name"], + "chunks": chunks, + "entity_data": entity, + } + ) + + if not entities_with_chunks: + logger.warning("No entities with text chunks found") + return [] + + # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities) + chunk_occurrence_count = {} + for entity_info in entities_with_chunks: + deduplicated_chunks = [] + for chunk_id in entity_info["chunks"]: + chunk_occurrence_count[chunk_id] = ( + chunk_occurrence_count.get(chunk_id, 0) + 1 + ) + + # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position) + if chunk_occurrence_count[chunk_id] == 1: + deduplicated_chunks.append(chunk_id) + # count > 1 means this chunk appeared in an earlier entity, so skip it + + # Update entity's chunks to deduplicated chunks + entity_info["chunks"] = deduplicated_chunks + + # Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority) + for entity_info in entities_with_chunks: + sorted_chunks = sorted( + entity_info["chunks"], + key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), + reverse=True, + ) + entity_info["sorted_chunks"] = sorted_chunks + + # Step 4: Apply linear gradient weighted polling algorithm + max_related_chunks = text_chunks_db.global_config.get( + "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER ) - logger.debug(f"Found {len(all_text_units)} entity-related chunks") + selected_chunk_ids = linear_gradient_weighted_polling( + entities_with_chunks, max_related_chunks, min_related_chunks=1 + ) - # Add source type marking and return chunk data + logger.debug( + f"Found {len(selected_chunk_ids)} entity-related chunks using linear gradient weighted polling" + ) + + if not selected_chunk_ids: + return [] + + # Step 5: Batch retrieve chunk data + unique_chunk_ids = list( + dict.fromkeys(selected_chunk_ids) + ) # Remove duplicates while preserving order + chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids) + + # Step 6: Build result chunks with valid data result_chunks = [] - for t in all_text_units: - chunk_data = t["data"].copy() - chunk_data["source_type"] = "entity" - result_chunks.append(chunk_data) + for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list): + if chunk_data is not None and "content" in chunk_data: + chunk_data_copy = chunk_data.copy() + chunk_data_copy["source_type"] = "entity" + chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication + result_chunks.append(chunk_data_copy) return result_chunks @@ -2575,19 +2624,12 @@ async def _get_edge_data( ) if not len(results): - return "", "", [], [] + return [], [] # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] - # For edge degrees, use tuples. - edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results] - - # Call the batched functions concurrently. - edge_data_dict, edge_degrees_dict = await asyncio.gather( - knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), - knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples), - ) + edge_data_dict = await knowledge_graph_inst.get_edges_batch(edge_pairs_dicts) # Reconstruct edge_datas list in the same order as results. edge_datas = [] @@ -2601,19 +2643,16 @@ async def _get_edge_data( ) edge_props["weight"] = 1.0 - # Use edge degree from the batch as rank. 
+ # Keep edge data without rank, maintain vector search order combined = { "src_id": k["src_id"], "tgt_id": k["tgt_id"], - "rank": edge_degrees_dict.get(pair, 0), "created_at": k.get("created_at", None), **edge_props, } edge_datas.append(combined) - edge_datas = sorted( - edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True - ) + # Relations maintain vector search order (sorted by similarity) use_entities = await _find_most_related_entities_from_relationships( edge_datas, @@ -2625,50 +2664,7 @@ async def _get_edge_data( f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations" ) - relations_context = [] - for i, e in enumerate(edge_datas): - created_at = e.get("created_at", "UNKNOWN") - # Convert timestamp to readable format - if isinstance(created_at, (int, float)): - created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) - - # Get file path from edge data - file_path = e.get("file_path", "unknown_source") - - relations_context.append( - { - "id": i + 1, - "entity1": e["src_id"], - "entity2": e["tgt_id"], - "description": e["description"], - "created_at": created_at, - "file_path": file_path, - } - ) - - entities_context = [] - for i, n in enumerate(use_entities): - created_at = n.get("created_at", "UNKNOWN") - # Convert timestamp to readable format - if isinstance(created_at, (int, float)): - created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) - - # Get file path from node data - file_path = n.get("file_path", "unknown_source") - - entities_context.append( - { - "id": i + 1, - "entity": n["entity_name"], - "type": n.get("entity_type", "UNKNOWN"), - "description": n.get("description", "UNKNOWN"), - "created_at": created_at, - "file_path": file_path, - } - ) - - # Return original data for later text chunk retrieval - return entities_context, relations_context, edge_datas, use_entities + return edge_datas, use_entities async def _find_most_related_entities_from_relationships( @@ -2687,22 +2683,18 @@ async def _find_most_related_entities_from_relationships( entity_names.append(e["tgt_id"]) seen.add(e["tgt_id"]) - # Batch approach: Retrieve nodes and their degrees concurrently with one query each. - nodes_dict, degrees_dict = await asyncio.gather( - knowledge_graph_inst.get_nodes_batch(entity_names), - knowledge_graph_inst.node_degrees_batch(entity_names), - ) + # Only get nodes data, no need for node degrees + nodes_dict = await knowledge_graph_inst.get_nodes_batch(entity_names) # Rebuild the list in the same order as entity_names node_datas = [] for entity_name in entity_names: node = nodes_dict.get(entity_name) - degree = degrees_dict.get(entity_name, 0) if node is None: logger.warning(f"Node '{entity_name}' not found in batch retrieval.") continue - # Combine the node data with the entity name and computed degree (as rank) - combined = {**node, "entity_name": entity_name, "rank": degree} + # Combine the node data with the entity name, no rank needed + combined = {**node, "entity_name": entity_name} node_datas.append(combined) return node_datas @@ -2712,71 +2704,132 @@ async def _find_related_text_unit_from_relationships( edge_datas: list[dict], query_param: QueryParam, text_chunks_db: BaseKVStorage, + entity_chunks: list[dict] = None, ): + """ + Find text chunks related to relationships using linear gradient weighted polling algorithm. + + This function implements the optimized text chunk selection strategy: + 1. Sort text chunks for each relationship by occurrence count in other relationships + 2. 
Use linear gradient weighted polling to select chunks fairly + """ logger.debug(f"Searching text chunks for {len(edge_datas)} relationships") - text_units = [ - split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])[ - : text_chunks_db.global_config.get( - "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER + if not edge_datas: + return [] + + # Step 1: Collect all text chunks for each relationship + relations_with_chunks = [] + for relation in edge_datas: + if relation.get("source_id"): + chunks = split_string_by_multi_markers( + relation["source_id"], [GRAPH_FIELD_SEP] ) + if chunks: + # Build relation identifier + if "src_tgt" in relation: + rel_key = tuple(sorted(relation["src_tgt"])) + else: + rel_key = tuple( + sorted([relation.get("src_id"), relation.get("tgt_id")]) + ) + + relations_with_chunks.append( + { + "relation_key": rel_key, + "chunks": chunks, + "relation_data": relation, + } + ) + + if not relations_with_chunks: + logger.warning("No relationships with text chunks found") + return [] + + # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships) + chunk_occurrence_count = {} + for relation_info in relations_with_chunks: + deduplicated_chunks = [] + for chunk_id in relation_info["chunks"]: + chunk_occurrence_count[chunk_id] = ( + chunk_occurrence_count.get(chunk_id, 0) + 1 + ) + + # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position) + if chunk_occurrence_count[chunk_id] == 1: + deduplicated_chunks.append(chunk_id) + # count > 1 means this chunk appeared in an earlier relationship, so skip it + + # Update relationship's chunks to deduplicated chunks + relation_info["chunks"] = deduplicated_chunks + + # Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority) + for relation_info in relations_with_chunks: + sorted_chunks = sorted( + relation_info["chunks"], + key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), + reverse=True, + ) + relation_info["sorted_chunks"] = sorted_chunks + + # Step 4: Apply linear gradient weighted polling algorithm + max_related_chunks = text_chunks_db.global_config.get( + "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER + ) + + selected_chunk_ids = linear_gradient_weighted_polling( + relations_with_chunks, max_related_chunks, min_related_chunks=1 + ) + + logger.debug( + f"Found {len(selected_chunk_ids)} relationship-related chunks using linear gradient weighted polling" + ) + logger.info( + f"KG related chunks: {len(entity_chunks)} from entitys, {len(selected_chunk_ids)} from relations" + ) + + if not selected_chunk_ids: + return [] + + # Step 4.5: Remove duplicates with entity_chunks before batch retrieval + if entity_chunks: + # Extract chunk IDs from entity_chunks + entity_chunk_ids = set() + for chunk in entity_chunks: + chunk_id = chunk.get("chunk_id") + if chunk_id: + entity_chunk_ids.add(chunk_id) + + # Filter out duplicate chunk IDs + original_count = len(selected_chunk_ids) + selected_chunk_ids = [ + chunk_id + for chunk_id in selected_chunk_ids + if chunk_id not in entity_chunk_ids ] - for dp in edge_datas - if dp["source_id"] is not None - ] - all_text_units_lookup = {} - # Deduplicate and preserve order | {c_id:order} - text_units_unique_flat = {} - for index, unit_list in enumerate(text_units): - for c_id in unit_list: - if ( - c_id not in text_units_unique_flat - or index < text_units_unique_flat[c_id] - ): - # Keep the smallest order - text_units_unique_flat[c_id] = index + 
logger.debug( + f"Deduplication relation-chunks with entity-chunks: {original_count} -> {len(selected_chunk_ids)} chunks " + ) - if not text_units_unique_flat: - logger.warning("No valid text chunks found") - return [] + # Early return if no chunks remain after deduplication + if not selected_chunk_ids: + return [] - # Batch get all text chunk data - chunk_ids = list(text_units_unique_flat.keys()) - chunk_data_list = await text_chunks_db.get_by_ids(chunk_ids) + # Step 5: Batch retrieve chunk data + unique_chunk_ids = list( + dict.fromkeys(selected_chunk_ids) + ) # Remove duplicates while preserving order + chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids) - # Build lookup table, handling possible missing data - for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list): - if chunk_data is not None and "content" in chunk_data: - all_text_units_lookup[chunk_id] = { - "data": chunk_data, - "order": text_units_unique_flat[chunk_id], - } - - if not all_text_units_lookup: - logger.warning("No valid text chunks found") - return [] - - all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()] - all_text_units = sorted(all_text_units, key=lambda x: x["order"]) - - # Ensure all text chunks have content - valid_text_units = [ - t for t in all_text_units if t["data"] is not None and "content" in t["data"] - ] - - if not valid_text_units: - logger.warning("No valid text chunks after filtering") - return [] - - logger.debug(f"Found {len(valid_text_units)} relationship-related chunks") - - # Add source type marking and return chunk data + # Step 6: Build result chunks with valid data result_chunks = [] - for t in valid_text_units: - chunk_data = t["data"].copy() - chunk_data["source_type"] = "relationship" - result_chunks.append(chunk_data) + for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list): + if chunk_data is not None and "content" in chunk_data: + chunk_data_copy = chunk_data.copy() + chunk_data_copy["source_type"] = "relationship" + chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication + result_chunks.append(chunk_data_copy) return result_chunks @@ -2866,7 +2919,7 @@ async def naive_query( # Process chunks using unified processing with dynamic token limit processed_chunks = await process_chunks_unified( query=query, - chunks=chunks, + unique_chunks=chunks, query_param=query_param, global_config=global_config, source_type="vector", @@ -3163,145 +3216,3 @@ async def query_with_keywords( ) else: raise ValueError(f"Unknown mode {param.mode}") - - -async def apply_rerank_if_enabled( - query: str, - retrieved_docs: list[dict], - global_config: dict, - enable_rerank: bool = True, - top_n: int = None, -) -> list[dict]: - """ - Apply reranking to retrieved documents if rerank is enabled. - - Args: - query: The search query - retrieved_docs: List of retrieved documents - global_config: Global configuration containing rerank settings - enable_rerank: Whether to enable reranking from query parameter - top_n: Number of top documents to return after reranking - - Returns: - Reranked documents if rerank is enabled, otherwise original documents - """ - if not enable_rerank or not retrieved_docs: - return retrieved_docs - - rerank_func = global_config.get("rerank_model_func") - if not rerank_func: - logger.warning( - "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters." 
- ) - return retrieved_docs - - try: - logger.debug( - f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_n}" - ) - - # Apply reranking - let rerank_model_func handle top_k internally - reranked_docs = await rerank_func( - query=query, - documents=retrieved_docs, - top_n=top_n, - ) - if reranked_docs and len(reranked_docs) > 0: - if len(reranked_docs) > top_n: - reranked_docs = reranked_docs[:top_n] - logger.info( - f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}" - ) - return reranked_docs - else: - logger.warning("Rerank returned empty results, using original documents") - return retrieved_docs - - except Exception as e: - logger.error(f"Error during reranking: {e}, using original documents") - return retrieved_docs - - -async def process_chunks_unified( - query: str, - chunks: list[dict], - query_param: QueryParam, - global_config: dict, - source_type: str = "mixed", - chunk_token_limit: int = None, # Add parameter for dynamic token limit -) -> list[dict]: - """ - Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation. - - Args: - query: Search query for reranking - chunks: List of text chunks to process - query_param: Query parameters containing configuration - global_config: Global configuration dictionary - source_type: Source type for logging ("vector", "entity", "relationship", "mixed") - chunk_token_limit: Dynamic token limit for chunks (if None, uses default) - - Returns: - Processed and filtered list of text chunks - """ - if not chunks: - return [] - - # 1. Deduplication based on content - seen_content = set() - unique_chunks = [] - for chunk in chunks: - content = chunk.get("content", "") - if content and content not in seen_content: - seen_content.add(content) - unique_chunks.append(chunk) - - logger.debug( - f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})" - ) - - # 2. Apply reranking if enabled and query is provided - if query_param.enable_rerank and query and unique_chunks: - rerank_top_k = query_param.chunk_top_k or len(unique_chunks) - unique_chunks = await apply_rerank_if_enabled( - query=query, - retrieved_docs=unique_chunks, - global_config=global_config, - enable_rerank=query_param.enable_rerank, - top_n=rerank_top_k, - ) - logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})") - - # 3. Apply chunk_top_k limiting if specified - if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0: - if len(unique_chunks) > query_param.chunk_top_k: - unique_chunks = unique_chunks[: query_param.chunk_top_k] - logger.debug( - f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})" - ) - - # 4. 
Token-based final truncation - tokenizer = global_config.get("tokenizer") - if tokenizer and unique_chunks: - # Set default chunk_token_limit if not provided - if chunk_token_limit is None: - # Get default from query_param or global_config - chunk_token_limit = getattr( - query_param, - "max_total_tokens", - global_config.get("MAX_TOTAL_TOKENS", 32000), - ) - - original_count = len(unique_chunks) - unique_chunks = truncate_list_by_token_size( - unique_chunks, - key=lambda x: x.get("content", ""), - max_token_size=chunk_token_limit, - tokenizer=tokenizer, - ) - logger.debug( - f"Token truncation: {len(unique_chunks)} chunks from {original_count} " - f"(chunk available tokens: {chunk_token_limit}, source: {source_type})" - ) - - return unique_chunks diff --git a/lightrag/utils.py b/lightrag/utils.py index 171cf9f6..bd5aeab2 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -55,7 +55,7 @@ def get_env_value( # Use TYPE_CHECKING to avoid circular imports if TYPE_CHECKING: - from lightrag.base import BaseKVStorage + from lightrag.base import BaseKVStorage, QueryParam # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance @@ -85,8 +85,10 @@ def verbose_debug(msg: str, *args, **kwargs): formatted_msg = msg # Then truncate the formatted message truncated_msg = ( - formatted_msg[:100] + "..." if len(formatted_msg) > 100 else formatted_msg + formatted_msg[:150] + "..." if len(formatted_msg) > 150 else formatted_msg ) + # Remove consecutive newlines + truncated_msg = re.sub(r"\n+", "\n", truncated_msg) logger.debug(truncated_msg, **kwargs) @@ -777,39 +779,6 @@ def truncate_list_by_token_size( return list_data -def process_combine_contexts(*context_lists): - """ - Combine multiple context lists and remove duplicate content - - Args: - *context_lists: Any number of context lists - - Returns: - Combined context list with duplicates removed - """ - seen_content = {} - combined_data = [] - - # Iterate through all input context lists - for context_list in context_lists: - if not context_list: # Skip empty lists - continue - for item in context_list: - content_dict = { - k: v for k, v in item.items() if k != "id" and k != "created_at" - } - content_key = tuple(sorted(content_dict.items())) - if content_key not in seen_content: - seen_content[content_key] = item - combined_data.append(item) - - # Reassign IDs - for i, item in enumerate(combined_data): - item["id"] = str(i + 1) - - return combined_data - - def cosine_similarity(v1, v2): """Calculate cosine similarity between two vectors""" dot_product = np.dot(v1, v2) @@ -1673,6 +1642,86 @@ def check_storage_env_vars(storage_name: str) -> None: ) +def linear_gradient_weighted_polling( + entities_or_relations: list[dict], + max_related_chunks: int, + min_related_chunks: int = 1, +) -> list[str]: + """ + Linear gradient weighted polling algorithm for text chunk selection. + + This algorithm ensures that entities/relations with higher importance get more text chunks, + forming a linear decreasing allocation pattern. 
+ + Args: + entities_or_relations: List of entities or relations sorted by importance (high to low) + max_related_chunks: Expected number of text chunks for the highest importance entity/relation + min_related_chunks: Expected number of text chunks for the lowest importance entity/relation + + Returns: + List of selected text chunk IDs + """ + if not entities_or_relations: + return [] + + n = len(entities_or_relations) + if n == 1: + # Only one entity/relation, return its first max_related_chunks text chunks + entity_chunks = entities_or_relations[0].get("sorted_chunks", []) + return entity_chunks[:max_related_chunks] + + # Calculate expected text chunk count for each position (linear decrease) + expected_counts = [] + for i in range(n): + # Linear interpolation: from max_related_chunks to min_related_chunks + ratio = i / (n - 1) if n > 1 else 0 + expected = max_related_chunks - ratio * ( + max_related_chunks - min_related_chunks + ) + expected_counts.append(int(round(expected))) + + # First round allocation: allocate by expected values + selected_chunks = [] + used_counts = [] # Track number of chunks used by each entity + total_remaining = 0 # Accumulate remaining quotas + + for i, entity_rel in enumerate(entities_or_relations): + entity_chunks = entity_rel.get("sorted_chunks", []) + expected = expected_counts[i] + + # Actual allocatable count + actual = min(expected, len(entity_chunks)) + selected_chunks.extend(entity_chunks[:actual]) + used_counts.append(actual) + + # Accumulate remaining quota + remaining = expected - actual + if remaining > 0: + total_remaining += remaining + + # Second round allocation: multi-round scanning to allocate remaining quotas + for _ in range(total_remaining): + allocated = False + + # Scan entities one by one, allocate one chunk when finding unused chunks + for i, entity_rel in enumerate(entities_or_relations): + entity_chunks = entity_rel.get("sorted_chunks", []) + + # Check if there are still unused chunks + if used_counts[i] < len(entity_chunks): + # Allocate one chunk + selected_chunks.append(entity_chunks[used_counts[i]]) + used_counts[i] += 1 + allocated = True + break + + # If no chunks were allocated in this round, all entities are exhausted + if not allocated: + break + + return selected_chunks + + class TokenTracker: """Track token usage for LLM calls.""" @@ -1728,3 +1777,127 @@ class TokenTracker: f"Completion tokens: {usage['completion_tokens']}, " f"Total tokens: {usage['total_tokens']}" ) + + +async def apply_rerank_if_enabled( + query: str, + retrieved_docs: list[dict], + global_config: dict, + enable_rerank: bool = True, + top_n: int = None, +) -> list[dict]: + """ + Apply reranking to retrieved documents if rerank is enabled. + + Args: + query: The search query + retrieved_docs: List of retrieved documents + global_config: Global configuration containing rerank settings + enable_rerank: Whether to enable reranking from query parameter + top_n: Number of top documents to return after reranking + + Returns: + Reranked documents if rerank is enabled, otherwise original documents + """ + if not enable_rerank or not retrieved_docs: + return retrieved_docs + + rerank_func = global_config.get("rerank_model_func") + if not rerank_func: + logger.warning( + "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters." 
+ ) + return retrieved_docs + + try: + # Apply reranking - let rerank_model_func handle top_k internally + reranked_docs = await rerank_func( + query=query, + documents=retrieved_docs, + top_n=top_n, + ) + if reranked_docs and len(reranked_docs) > 0: + if len(reranked_docs) > top_n: + reranked_docs = reranked_docs[:top_n] + logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks") + return reranked_docs + else: + logger.warning("Rerank returned empty results, using original chunks") + return retrieved_docs + + except Exception as e: + logger.error(f"Error during reranking: {e}, using original chunks") + return retrieved_docs + + +async def process_chunks_unified( + query: str, + unique_chunks: list[dict], + query_param: "QueryParam", + global_config: dict, + source_type: str = "mixed", + chunk_token_limit: int = None, # Add parameter for dynamic token limit +) -> list[dict]: + """ + Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation. + + Args: + query: Search query for reranking + chunks: List of text chunks to process + query_param: Query parameters containing configuration + global_config: Global configuration dictionary + source_type: Source type for logging ("vector", "entity", "relationship", "mixed") + chunk_token_limit: Dynamic token limit for chunks (if None, uses default) + + Returns: + Processed and filtered list of text chunks + """ + if not unique_chunks: + return [] + + origin_count = len(unique_chunks) + + # 1. Apply reranking if enabled and query is provided + if query_param.enable_rerank and query and unique_chunks: + rerank_top_k = query_param.chunk_top_k or len(unique_chunks) + unique_chunks = await apply_rerank_if_enabled( + query=query, + retrieved_docs=unique_chunks, + global_config=global_config, + enable_rerank=query_param.enable_rerank, + top_n=rerank_top_k, + ) + + # 2. Apply chunk_top_k limiting if specified + if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0: + if len(unique_chunks) > query_param.chunk_top_k: + unique_chunks = unique_chunks[: query_param.chunk_top_k] + logger.debug( + f"Kept chunk_top-k: {len(unique_chunks)} chunks (deduplicated original: {origin_count})" + ) + + # 3. Token-based final truncation + tokenizer = global_config.get("tokenizer") + if tokenizer and unique_chunks: + # Set default chunk_token_limit if not provided + if chunk_token_limit is None: + # Get default from query_param or global_config + chunk_token_limit = getattr( + query_param, + "max_total_tokens", + global_config.get("MAX_TOTAL_TOKENS", 32000), + ) + + original_count = len(unique_chunks) + unique_chunks = truncate_list_by_token_size( + unique_chunks, + key=lambda x: x.get("content", ""), + max_token_size=chunk_token_limit, + tokenizer=tokenizer, + ) + logger.debug( + f"Token truncation: {len(unique_chunks)} chunks from {original_count} " + f"(chunk available tokens: {chunk_token_limit}, source: {source_type})" + ) + + return unique_chunks