diff --git a/lightrag/operate.py b/lightrag/operate.py
index 47bc2b3e..d8152050 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1892,26 +1892,42 @@ async def _build_query_context(
     entities_context = []
     relations_context = []
 
+    # Store original data for later text chunk retrieval
+    original_node_datas = []
+    original_edge_datas = []
+
     # Handle local and global modes
     if query_param.mode == "local":
-        entities_context, relations_context, entity_chunks = await _get_node_data(
+        (
+            entities_context,
+            relations_context,
+            node_datas,
+            use_relations,
+        ) = await _get_node_data(
             ll_keywords,
             knowledge_graph_inst,
             entities_vdb,
             text_chunks_db,
             query_param,
         )
-        all_chunks.extend(entity_chunks)
+        original_node_datas = node_datas
+        original_edge_datas = use_relations
     elif query_param.mode == "global":
-        entities_context, relations_context, relationship_chunks = await _get_edge_data(
+        (
+            entities_context,
+            relations_context,
+            edge_datas,
+            use_entities,
+        ) = await _get_edge_data(
             hl_keywords,
             knowledge_graph_inst,
             relationships_vdb,
             text_chunks_db,
             query_param,
         )
-        all_chunks.extend(relationship_chunks)
+        original_edge_datas = edge_datas
+        original_node_datas = use_entities
     else:  # hybrid or mix mode
         ll_data = await _get_node_data(
@@ -1929,10 +1945,13 @@ async def _build_query_context(
             query_param,
         )
 
-        (ll_entities_context, ll_relations_context, ll_chunks) = ll_data
-        (hl_entities_context, hl_relations_context, hl_chunks) = hl_data
+        (ll_entities_context, ll_relations_context, ll_node_datas, ll_edge_datas) = (
+            ll_data
+        )
+        (hl_entities_context, hl_relations_context, hl_edge_datas, hl_node_datas) = (
+            hl_data
+        )
 
-        # Collect chunks from entity and relationship sources
         # Get vector chunks first if in mix mode
         if query_param.mode == "mix" and chunks_vdb:
             vector_chunks = await _get_vector_context(
@@ -1942,8 +1961,9 @@ async def _build_query_context(
                 query_param,
             )
             all_chunks.extend(vector_chunks)
 
-        all_chunks.extend(ll_chunks)
-        all_chunks.extend(hl_chunks)
+        # Store original data from both sources
+        original_node_datas = ll_node_datas + hl_node_datas
+        original_edge_datas = ll_edge_datas + hl_edge_datas
 
     # Combine entities and relations contexts
     entities_context = process_combine_contexts(
@@ -2027,6 +2047,73 @@ async def _build_query_context(
             f"Truncated relations: {original_relation_count} -> {len(relations_context)} (relation max tokens: {max_relation_tokens})"
         )
 
+    # After truncation, get text chunks based on final entities and relations
+    logger.info("Getting text chunks based on truncated entities and relations...")
+
+    # Create filtered data based on truncated context
+    final_node_datas = []
+    final_edge_datas = []
+
+    if entities_context and original_node_datas:
+        # Create a set of entity names from final truncated context
+        final_entity_names = {entity["entity"] for entity in entities_context}
+        # Filter original node data based on final entities
+        final_node_datas = [
+            node
+            for node in original_node_datas
+            if node.get("entity_name") in final_entity_names
+        ]
+
+    if relations_context and original_edge_datas:
+        # Create a set of relation pairs from final truncated context
+        final_relation_pairs = {
+            (rel["entity1"], rel["entity2"]) for rel in relations_context
+        }
+        # Filter original edge data based on final relations
+        final_edge_datas = [
+            edge
+            for edge in original_edge_datas
+            if (edge.get("src_id"), edge.get("tgt_id")) in final_relation_pairs
+            or (
+                edge.get("src_tgt", (None, None))[0],
+                edge.get("src_tgt", (None, None))[1],
+            )
+            in final_relation_pairs
+        ]
+
+    # Get text chunks based on final filtered data
+    text_chunk_tasks = []
+
+    if final_node_datas:
+        text_chunk_tasks.append(
+            _find_most_related_text_unit_from_entities(
+                final_node_datas,
+                query_param,
+                text_chunks_db,
+                knowledge_graph_inst,
+            )
+        )
+
+    if final_edge_datas:
+        text_chunk_tasks.append(
+            _find_related_text_unit_from_relationships(
+                final_edge_datas,
+                query_param,
+                text_chunks_db,
+                knowledge_graph_inst,
+            )
+        )
+
+    # Execute text chunk retrieval in parallel
+    if text_chunk_tasks:
+        text_chunk_results = await asyncio.gather(*text_chunk_tasks)
+        for chunks in text_chunk_results:
+            if chunks:
+                all_chunks.extend(chunks)
+
+    # Apply token processing to chunks if tokenizer is available
+    text_units_context = []
+    if tokenizer and all_chunks:
         # Calculate dynamic token limit for text chunks
         entities_str = json.dumps(entities_context, ensure_ascii=False)
         relations_str = json.dumps(relations_context, ensure_ascii=False)
@@ -2122,7 +2209,6 @@
         )
 
         # Rebuild text_units_context with truncated chunks
-        text_units_context = []
         for i, chunk in enumerate(truncated_chunks):
             text_units_context.append(
                 {
@@ -2187,7 +2273,7 @@ async def _get_node_data(
     )
 
     if not len(results):
-        return "", "", ""
+        return "", "", [], []
 
     # Extract all entity IDs from your results list
     node_ids = [r["entity_name"] for r in results]
@@ -2214,14 +2300,8 @@ async def _get_node_data(
         }
         for k, n, d in zip(results, node_datas, node_degrees)
         if n is not None
-    ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
-    # get entitytext chunk
-    use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas,
-        query_param,
-        text_chunks_db,
-        knowledge_graph_inst,
-    )
+    ]
+
     use_relations = await _find_most_related_edges_from_entities(
         node_datas,
         query_param,
@@ -2229,7 +2309,7 @@ async def _get_node_data(
     )
 
     logger.info(
-        f"Local query: {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
+        f"Local query: {len(node_datas)} entities, {len(use_relations)} relations"
     )
 
     # build prompt
@@ -2278,7 +2358,7 @@ async def _get_node_data(
         }
     )
 
-    return entities_context, relations_context, use_text_units
+    return entities_context, relations_context, node_datas, use_relations
 
 
 async def _find_most_related_text_unit_from_entities(
@@ -2456,7 +2536,7 @@ async def _get_edge_data(
     )
 
     if not len(results):
-        return "", "", ""
+        return "", "", [], []
 
     # Prepare edge pairs in two forms:
     # For the batch edge properties function, use dicts.
@@ -2495,21 +2575,15 @@ async def _get_edge_data(
     edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
-    use_entities, use_text_units = await asyncio.gather(
-        _find_most_related_entities_from_relationships(
-            edge_datas,
-            query_param,
-            knowledge_graph_inst,
-        ),
-        _find_related_text_unit_from_relationships(
-            edge_datas,
-            query_param,
-            text_chunks_db,
-            knowledge_graph_inst,
-        ),
+
+    use_entities = await _find_most_related_entities_from_relationships(
+        edge_datas,
+        query_param,
+        knowledge_graph_inst,
     )
+
     logger.info(
-        f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
+        f"Global query: {len(use_entities)} entities, {len(edge_datas)} relations"
     )
 
     relations_context = []
@@ -2558,16 +2632,8 @@
         }
     )
 
-    text_units_context = []
-    for i, t in enumerate(use_text_units):
-        text_units_context.append(
-            {
-                "id": i + 1,
-                "content": t["content"],
-                "file_path": t.get("file_path", "unknown"),
-            }
-        )
-    return entities_context, relations_context, text_units_context
+    # Return original data for later text chunk retrieval
+    return entities_context, relations_context, edge_datas, use_entities
 
 
 async def _find_most_related_entities_from_relationships(
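Reviewer note: the net effect of this diff is a reordering. Text chunks are now fetched after entity/relation truncation, so no retrieval work is spent on records that truncation discards anyway. The sketch below is a minimal, self-contained illustration of that truncate-then-fetch flow, not LightRAG code: fetch_entity_chunks and fetch_relation_chunks are hypothetical stand-ins for _find_most_related_text_unit_from_entities and _find_related_text_unit_from_relationships, and the record keys (entity, entity_name, entity1/entity2, src_id/tgt_id, src_tgt) mirror the ones used in the diff.

import asyncio

# Hypothetical stand-ins for the two chunk-retrieval helpers; the real
# versions look up each record's source chunks in text_chunks_db.
async def fetch_entity_chunks(nodes):
    return [{"content": f"chunks for entity {n['entity_name']}"} for n in nodes]

async def fetch_relation_chunks(edges):
    return [{"content": f"chunks for edge {e['src_id']}->{e['tgt_id']}"} for e in edges]

async def collect_chunks(entities_context, relations_context, node_datas, edge_datas):
    # Keep only the node/edge records that survived context truncation.
    keep_names = {e["entity"] for e in entities_context}
    final_nodes = [n for n in node_datas if n.get("entity_name") in keep_names]

    keep_pairs = {(r["entity1"], r["entity2"]) for r in relations_context}
    final_edges = [
        e for e in edge_datas
        if (e.get("src_id"), e.get("tgt_id")) in keep_pairs
        or tuple(e.get("src_tgt", (None, None))) in keep_pairs
    ]

    # Fetch chunks for both sources concurrently, skipping empty inputs.
    tasks = []
    if final_nodes:
        tasks.append(fetch_entity_chunks(final_nodes))
    if final_edges:
        tasks.append(fetch_relation_chunks(final_edges))

    all_chunks = []
    for chunks in await asyncio.gather(*tasks):
        all_chunks.extend(chunks)
    return all_chunks

# "Turing" was truncated out of the entity context, so its chunks are
# never fetched.
print(asyncio.run(collect_chunks(
    [{"entity": "Ada"}],
    [{"entity1": "Ada", "entity2": "Babbage"}],
    [{"entity_name": "Ada"}, {"entity_name": "Turing"}],
    [{"src_id": "Ada", "tgt_id": "Babbage"}],
)))

The same skip-empty guard appears in the diff: each retrieval coroutine is only added to text_chunk_tasks when its filtered input is non-empty, and asyncio.gather runs whatever remains in parallel.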