diff --git a/lightrag/operate.py b/lightrag/operate.py
index 88bb7349..43f20eca 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1185,9 +1185,9 @@ async def _merge_nodes_then_upsert(
     # Log based on actual LLM usage
     if llm_was_used:
-        status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+        status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
     else:
-        status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+        status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
 
     logger.info(status_message)
     if pipeline_status is not None and pipeline_status_lock is not None:
@@ -1302,9 +1302,9 @@ async def _merge_edges_then_upsert(
     # Log based on actual LLM usage
     if llm_was_used:
-        status_message = f"LLMmrg: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+        status_message = f"LLMmrg: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
     else:
-        status_message = f"Merged: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}{dd_message}"
+        status_message = f"Merged: `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment - already_fragment}{dd_message}"
 
     logger.info(status_message)
     if pipeline_status is not None and pipeline_status_lock is not None:
@@ -2285,7 +2285,7 @@ async def _get_vector_context(
         return []
 
 
-async def _build_query_context(
+async def _perform_kg_search(
     query: str,
     ll_keywords: str,
     hl_keywords: str,
@@ -2295,25 +2295,21 @@
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
     chunks_vdb: BaseVectorStorage = None,
-):
-    if not query:
-        logger.warning("Query is empty, skipping context building")
-        return ""
+) -> dict[str, Any]:
+    """
+    Pure search logic that retrieves raw entities, relations, and vector chunks.
+    No token truncation or formatting - just raw search results.
+ """ - logger.info(f"Process {os.getpid()} building query context...") - - # Collect chunks from different sources separately - vector_chunks = [] - entity_chunks = [] - relation_chunks = [] - entities_context = [] - relations_context = [] - - # Store original data for later text chunk retrieval + # Initialize result containers local_entities = [] local_relations = [] global_entities = [] global_relations = [] + vector_chunks = [] + chunk_tracking = {} + + # Handle different query modes # Track chunk sources and metadata for final logging chunk_tracking = {} # chunk_id -> {source, frequency, order} @@ -2369,7 +2365,7 @@ async def _build_query_context( query_param, ) - # Get vector chunks first if in mix mode + # Get vector chunks for mix mode if query_param.mode == "mix" and chunks_vdb: vector_chunks = await _get_vector_context( query, @@ -2389,11 +2385,9 @@ async def _build_query_context( else: logger.warning(f"Vector chunk missing chunk_id: {chunk}") - # Use round-robin merge to combine local and global data fairly + # Round-robin merge entities final_entities = [] seen_entities = set() - - # Round-robin merge entities max_len = max(len(local_entities), len(global_entities)) for i in range(max_len): # First from local @@ -2415,7 +2409,6 @@ async def _build_query_context( # Round-robin merge relations final_relations = [] seen_relations = set() - max_len = max(len(local_relations), len(global_relations)) for i in range(max_len): # First from local @@ -2448,91 +2441,107 @@ async def _build_query_context( final_relations.append(relation) seen_relations.add(rel_key) - # Generate entities context + logger.info( + f"Raw search results: {len(final_entities)} entities, {len(final_relations)} relations, {len(vector_chunks)} vector chunks" + ) + + return { + "final_entities": final_entities, + "final_relations": final_relations, + "vector_chunks": vector_chunks, + "chunk_tracking": chunk_tracking, + "query_embedding": query_embedding, + } + + +async def _apply_token_truncation( + search_result: dict[str, Any], + query_param: QueryParam, + global_config: dict[str, str], +) -> dict[str, Any]: + """ + Apply token-based truncation to entities and relations for LLM efficiency. + This function is only used by kg_query, not kg_search. 
+ """ + tokenizer = global_config.get("tokenizer") + if not tokenizer: + logger.warning("No tokenizer found, skipping truncation") + return { + "truncated_entities": search_result["final_entities"], + "truncated_relations": search_result["final_relations"], + "entities_context": [], + "relations_context": [], + "filtered_entities": search_result["final_entities"], + "filtered_relations": search_result["final_relations"], + } + + # Get token limits from query_param with fallbacks + max_entity_tokens = getattr( + query_param, + "max_entity_tokens", + global_config.get("max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS), + ) + max_relation_tokens = getattr( + query_param, + "max_relation_tokens", + global_config.get("max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS), + ) + + final_entities = search_result["final_entities"] + final_relations = search_result["final_relations"] + + # Generate entities context for truncation entities_context = [] - for i, n in enumerate(final_entities): - created_at = n.get("created_at", "UNKNOWN") + for i, entity in enumerate(final_entities): + created_at = entity.get("created_at", "UNKNOWN") if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) - # Get file path from node data - file_path = n.get("file_path", "unknown_source") - entities_context.append( { "id": i + 1, - "entity": n["entity_name"], - "type": n.get("entity_type", "UNKNOWN"), - "description": n.get("description", "UNKNOWN"), + "entity": entity["entity_name"], + "type": entity.get("entity_type", "UNKNOWN"), + "description": entity.get("description", "UNKNOWN"), "created_at": created_at, - "file_path": file_path, + "file_path": entity.get("file_path", "unknown_source"), } ) - # Generate relations context + # Generate relations context for truncation relations_context = [] - for i, e in enumerate(final_relations): - created_at = e.get("created_at", "UNKNOWN") - # Convert timestamp to readable format + for i, relation in enumerate(final_relations): + created_at = relation.get("created_at", "UNKNOWN") if isinstance(created_at, (int, float)): created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) - # Get file path from edge data - file_path = e.get("file_path", "unknown_source") - # Handle different relation data formats - if "src_tgt" in e: - entity1, entity2 = e["src_tgt"] + if "src_tgt" in relation: + entity1, entity2 = relation["src_tgt"] else: - entity1, entity2 = e.get("src_id"), e.get("tgt_id") + entity1, entity2 = relation.get("src_id"), relation.get("tgt_id") relations_context.append( { "id": i + 1, "entity1": entity1, "entity2": entity2, - "description": e.get("description", "UNKNOWN"), + "description": relation.get("description", "UNKNOWN"), "created_at": created_at, - "file_path": file_path, + "file_path": relation.get("file_path", "unknown_source"), } ) logger.debug( - f"Initial KG query results: {len(entities_context)} entities, {len(relations_context)} relations" + f"Before truncation: {len(entities_context)} entities, {len(relations_context)} relations" ) - # Unified token control system - Apply precise token limits to entities and relations - tokenizer = text_chunks_db.global_config.get("tokenizer") - # Get new token limits from query_param (with fallback to global_config) - max_entity_tokens = getattr( - query_param, - "max_entity_tokens", - text_chunks_db.global_config.get( - "max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS - ), - ) - max_relation_tokens = getattr( - query_param, - "max_relation_tokens", - 
-        text_chunks_db.global_config.get(
-            "max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS
-        ),
-    )
-    max_total_tokens = getattr(
-        query_param,
-        "max_total_tokens",
-        text_chunks_db.global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
-    )
-
-    # Truncate entities based on complete JSON serialization
+    # Apply token-based truncation
     if entities_context:
-        # Process entities context to replace GRAPH_FIELD_SEP with : in file_path fields
+        # Remove file_path and created_at for token calculation
         for entity in entities_context:
-            # remove file_path and created_at
             entity.pop("file_path", None)
             entity.pop("created_at", None)
-        # if "file_path" in entity and entity["file_path"]:
-        #     entity["file_path"] = entity["file_path"].replace(GRAPH_FIELD_SEP, ";")
 
         entities_context = truncate_list_by_token_size(
             entities_context,
@@ -2541,17 +2550,11 @@
             tokenizer=tokenizer,
         )
 
-    # Truncate relations based on complete JSON serialization
     if relations_context:
-        # Process relations context to replace GRAPH_FIELD_SEP with : in file_path fields
+        # Remove file_path and created_at for token calculation
        for relation in relations_context:
-            # remove file_path and created_at
            relation.pop("file_path", None)
            relation.pop("created_at", None)
-        # if "file_path" in relation and relation["file_path"]:
-        #     relation["file_path"] = relation["file_path"].replace(
-        #         GRAPH_FIELD_SEP, ";"
-        #     )
 
         relations_context = truncate_list_by_token_size(
             relations_context,
@@ -2560,41 +2563,86 @@
             tokenizer=tokenizer,
         )
 
-    # After truncation, get text chunks based on final entities and relations
     logger.info(
-        f"Truncated KG query results: {len(entities_context)} entities, {len(relations_context)} relations"
+        f"After truncation: {len(entities_context)} entities, {len(relations_context)} relations"
     )
 
-    # Create filtered data based on truncated context
-    final_node_datas = []
-    if entities_context and final_entities:
+    # Create filtered original data based on truncated context
+    filtered_entities = []
+    if entities_context:
         final_entity_names = {e["entity"] for e in entities_context}
         seen_nodes = set()
-        for node in final_entities:
-            name = node.get("entity_name")
+        for entity in final_entities:
+            name = entity.get("entity_name")
            if name in final_entity_names and name not in seen_nodes:
-                final_node_datas.append(node)
+                filtered_entities.append(entity)
                seen_nodes.add(name)
 
-    final_edge_datas = []
-    if relations_context and final_relations:
+    filtered_relations = []
+    if relations_context:
         final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context}
         seen_edges = set()
-        for edge in final_relations:
-            src, tgt = edge.get("src_id"), edge.get("tgt_id")
+        for relation in final_relations:
+            src, tgt = relation.get("src_id"), relation.get("tgt_id")
             if src is None or tgt is None:
-                src, tgt = edge.get("src_tgt", (None, None))
+                src, tgt = relation.get("src_tgt", (None, None))
             pair = (src, tgt)
             if pair in final_relation_pairs and pair not in seen_edges:
-                final_edge_datas.append(edge)
+                filtered_relations.append(relation)
                 seen_edges.add(pair)
 
-    # Get text chunks based on final filtered data
-    # To preserve the influence of entity order, entiy-based chunks should not be deduplcicated by vector_chunks
-    if final_node_datas:
+    return {
+        "truncated_entities": final_entities,  # Keep original for backward compatibility
+        "truncated_relations": final_relations,  # Keep original for backward compatibility
+        "entities_context": entities_context,  # Formatted and truncated for LLM
+        "relations_context": relations_context,  # Formatted and truncated for LLM
+        "filtered_entities": filtered_entities,  # Original entities that passed truncation
+        "filtered_relations": filtered_relations,  # Original relations that passed truncation
+    }
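`truncate_list_by_token_size` is called but never shown in this hunk. As a reading aid, here is a hedged approximation of what that kind of helper does (the real one lives in lightrag's utils and uses the configured tokenizer; the whitespace count below is a stand-in):

```python
import json


def truncate_by_token_budget(items: list[dict], max_tokens: int) -> list[dict]:
    """Keep ranked items until an (approximate) serialized token budget is spent."""
    kept: list[dict] = []
    used = 0
    for item in items:
        # Stand-in cost model: whitespace tokens of the JSON serialization.
        cost = len(json.dumps(item, ensure_ascii=False).split())
        if used + cost > max_tokens:
            break  # the lists arrive ranked, so stop at the first overflow
        kept.append(item)
        used += cost
    return kept
```

Stopping at the first overflow keeps the survivors a prefix of the ranked list, so the `filtered_entities`/`filtered_relations` reconstruction above can recover the surviving originals by matching names and pairs alone.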
+
+
+async def _merge_all_chunks(
+    search_result: dict[str, Any],
+    filtered_entities: list[dict] = None,
+    filtered_relations: list[dict] = None,
+    query: str = "",
+    knowledge_graph_inst: BaseGraphStorage = None,
+    text_chunks_db: BaseKVStorage = None,
+    query_param: QueryParam = None,
+    chunks_vdb: BaseVectorStorage = None,
+    chunk_tracking: dict = None,
+) -> list[dict]:
+    """
+    Merge chunks from different sources: vector_chunks + entity_chunks + relation_chunks.
+
+    For kg_search: uses all original entities/relations.
+    For kg_query: uses filtered entities/relations based on token truncation.
+    """
+    if chunk_tracking is None:
+        chunk_tracking = search_result.get("chunk_tracking", {})
+
+    # Use filtered entities/relations if provided (kg_query), otherwise use all (kg_search)
+    entities_to_use = (
+        filtered_entities
+        if filtered_entities is not None
+        else search_result["final_entities"]
+    )
+    relations_to_use = (
+        filtered_relations
+        if filtered_relations is not None
+        else search_result["final_relations"]
+    )
+    vector_chunks = search_result["vector_chunks"]
+
+    # Pre-compute the query embedding once: both the entity and the relation
+    # lookups below need it, so it must not be bound inside the entity branch only
+    query_embedding = search_result.get("query_embedding", None)
+
+    # Get chunks from entities
+    entity_chunks = []
+    if entities_to_use and text_chunks_db:
         entity_chunks = await _find_related_text_unit_from_entities(
-            final_node_datas,
+            entities_to_use,
             query_param,
             text_chunks_db,
             knowledge_graph_inst,
@@ -2604,21 +2652,21 @@
             query,
             chunks_vdb,
             chunk_tracking=chunk_tracking,
             query_embedding=query_embedding,
         )
 
-    # Find deduplcicated chunks from edge
-    # Deduplication cause chunks solely relation-based to be prioritized and sent to the LLM when re-ranking is disabled
-    if final_edge_datas:
+    # Get chunks from relations
+    relation_chunks = []
+    if relations_to_use and text_chunks_db:
         relation_chunks = await _find_related_text_unit_from_relations(
-            final_edge_datas,
+            relations_to_use,
             query_param,
             text_chunks_db,
-            entity_chunks,
+            entity_chunks,  # For deduplication
             query,
             chunks_vdb,
             chunk_tracking=chunk_tracking,
             query_embedding=query_embedding,
         )
 
-    # Round-robin merge chunks from different sources with deduplication by chunk_id
+    # Round-robin merge chunks from different sources with deduplication
     merged_chunks = []
     seen_chunk_ids = set()
     max_len = max(len(vector_chunks), len(entity_chunks), len(relation_chunks))
@@ -2668,12 +2716,238 @@
     )
 
     logger.info(
-        f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
+        f"Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplication: {origin_len - len(merged_chunks)})"
+    )
+
+    return merged_chunks
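One subtlety worth calling out: `_merge_all_chunks` treats `None` as "the caller skipped truncation", while an empty list is a genuine (empty) truncation result, which is why the selection above uses `is not None` rather than truthiness. A minimal illustration of the convention:

```python
def select_inputs(filtered: list | None, everything: list) -> list:
    """None means "no truncation happened"; an empty list is a real result."""
    return filtered if filtered is not None else everything


assert select_inputs(None, [1, 2, 3]) == [1, 2, 3]  # kg_search: use everything
assert select_inputs([], [1, 2, 3]) == []           # kg_query: all items truncated away
```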
"metadata": { + "query_mode": query_param.mode, + "keywords": {"high_level": [], "low_level": []}, + }, + } + + # Handle cache (reuse existing cache logic but for search results) + args_hash = compute_args_hash( + query_param.mode, + query, + "search", # Different cache key for search vs query + query_param.top_k, + query_param.chunk_top_k, + query_param.max_entity_tokens, + query_param.max_relation_tokens, + query_param.max_total_tokens, + query_param.hl_keywords or [], + query_param.ll_keywords or [], + query_param.user_prompt or "", + query_param.enable_rerank, + ) + cached_response = await handle_cache( + hashing_kv, args_hash, query, query_param.mode, cache_type="search" + ) + if cached_response is not None: + try: + return json_repair.loads(cached_response) + except (json.JSONDecodeError, KeyError): + logger.warning( + "Invalid cache format for search results, proceeding with fresh search" + ) + + # Get keywords (reuse existing logic) + hl_keywords, ll_keywords = await get_keywords_from_query( + query, query_param, global_config, hashing_kv + ) + + logger.debug(f"High-level keywords: {hl_keywords}") + logger.debug(f"Low-level keywords: {ll_keywords}") + + # Handle empty keywords (reuse existing logic) + if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]: + logger.warning("low_level_keywords is empty") + if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]: + logger.warning("high_level_keywords is empty") + if hl_keywords == [] and ll_keywords == []: + if len(query) < 50: + logger.warning(f"Forced low_level_keywords to origin query: {query}") + ll_keywords = [query] + else: + return { + "entities": [], + "relationships": [], + "chunks": [], + "metadata": { + "query_mode": query_param.mode, + "keywords": {"high_level": hl_keywords, "low_level": ll_keywords}, + "error": "Keywords extraction failed", + }, + } + + ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else "" + hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" + + # Stage 1: Pure search (no truncation for kg_search) + search_result = await _perform_kg_search( + query, + ll_keywords_str, + hl_keywords_str, + knowledge_graph_inst, + entities_vdb, + relationships_vdb, + query_param, + chunks_vdb, + ) + + if not search_result["final_entities"] and not search_result["final_relations"]: + return { + "entities": [], + "relationships": [], + "chunks": [], + "metadata": { + "query_mode": query_param.mode, + "keywords": {"high_level": hl_keywords, "low_level": ll_keywords}, + "error": "No valid results found", + }, + } + + # Stage 2: Merge ALL chunks (no filtering, use all entities/relations) + merged_chunks = await _merge_all_chunks( + search_result, + filtered_entities=None, # Use ALL entities (no filtering) + filtered_relations=None, # Use ALL relations (no filtering) + query=query, + knowledge_graph_inst=knowledge_graph_inst, + text_chunks_db=text_chunks_db, + query_param=query_param, + chunks_vdb=chunks_vdb, + chunk_tracking=search_result["chunk_tracking"], + ) + + # Build final structured result + final_result = { + "entities": search_result["final_entities"], + "relationships": search_result["final_relations"], + "chunks": merged_chunks, + "metadata": { + "query_mode": query_param.mode, + "keywords": {"high_level": hl_keywords, "low_level": ll_keywords}, + }, + } + + # Cache the results + if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): + queryparam_dict = { + "mode": query_param.mode, + "response_type": "search", + "top_k": query_param.top_k, + 
"chunk_top_k": query_param.chunk_top_k, + "max_entity_tokens": query_param.max_entity_tokens, + "max_relation_tokens": query_param.max_relation_tokens, + "max_total_tokens": query_param.max_total_tokens, + "hl_keywords": query_param.hl_keywords or [], + "ll_keywords": query_param.ll_keywords or [], + "user_prompt": query_param.user_prompt or "", + "enable_rerank": query_param.enable_rerank, + } + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=json.dumps(final_result, ensure_ascii=False), + prompt=query, + mode=query_param.mode, + cache_type="search", + queryparam=queryparam_dict, + ), + ) + + return final_result + + +async def _build_llm_context( + entities_context: list[dict], + relations_context: list[dict], + merged_chunks: list[dict], + query: str, + query_param: QueryParam, + global_config: dict[str, str], + chunk_tracking: dict = None, +) -> str: + """ + Build the final LLM context string with token processing. + This includes dynamic token calculation and final chunk truncation. + """ + tokenizer = global_config.get("tokenizer") + if not tokenizer: + logger.warning("No tokenizer found, building context without token limits") + + # Build basic context without token processing + entities_str = json.dumps(entities_context, ensure_ascii=False) + relations_str = json.dumps(relations_context, ensure_ascii=False) + + text_units_context = [] + for i, chunk in enumerate(merged_chunks): + text_units_context.append( + { + "id": i + 1, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } + ) + + text_units_str = json.dumps(text_units_context, ensure_ascii=False) + + return f"""-----Entities(KG)----- + +```json +{entities_str} +``` + +-----Relationships(KG)----- + +```json +{relations_str} +``` + +-----Document Chunks(DC)----- + +```json +{text_units_str} +``` + +""" + + # Get token limits + max_total_tokens = getattr( + query_param, + "max_total_tokens", + global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS), ) - # Apply token processing to merged chunks text_units_context = [] truncated_chunks = [] + if merged_chunks: # Calculate dynamic token limit for text chunks entities_str = json.dumps(entities_context, ensure_ascii=False) @@ -2704,18 +2978,7 @@ async def _build_query_context( ) kg_context_tokens = len(tokenizer.encode(kg_context)) - # Calculate actual system prompt overhead dynamically - # 1. Converstion history not included in context length calculation - history_context = "" - # if query_param.conversation_history: - # history_context = get_conversation_turns( - # query_param.conversation_history, query_param.history_turns - # ) - # history_tokens = ( - # len(tokenizer.encode(history_context)) if history_context else 0 - # ) - - # 2. 
+# _build_query_context now delegates to the staged helpers defined above
+async def _build_query_context(
+    query: str,
+    ll_keywords: str,
+    hl_keywords: str,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage,
+    query_param: QueryParam,
+    chunks_vdb: BaseVectorStorage = None,
+) -> str | None:
+    """
+    Main query context building function using the new 4-stage architecture:
+    1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
+    """
+
+    if not query:
+        logger.warning("Query is empty, skipping context building")
+        return ""
+
+    # Stage 1: Pure search
+    search_result = await _perform_kg_search(
+        query,
+        ll_keywords,
+        hl_keywords,
+        knowledge_graph_inst,
+        entities_vdb,
+        relationships_vdb,
+        query_param,
+        chunks_vdb,
+    )
+
+    if not search_result["final_entities"] and not search_result["final_relations"]:
+        return None
+
+    # Stage 2: Apply token truncation for LLM efficiency
+    truncation_result = await _apply_token_truncation(
+        search_result,
+        query_param,
+        text_chunks_db.global_config,
+    )
+
+    # Stage 3: Merge chunks using filtered entities/relations
+    merged_chunks = await _merge_all_chunks(
+        search_result,
+        filtered_entities=truncation_result["filtered_entities"],
+        filtered_relations=truncation_result["filtered_relations"],
+        query=query,
+        knowledge_graph_inst=knowledge_graph_inst,
+        text_chunks_db=text_chunks_db,
+        query_param=query_param,
+        chunks_vdb=chunks_vdb,
+        chunk_tracking=search_result["chunk_tracking"],
+    )
+
+    # Stage 4: Build final LLM context with dynamic token processing
+    context = await _build_llm_context(
+        entities_context=truncation_result["entities_context"],
+        relations_context=truncation_result["relations_context"],
+        merged_chunks=merged_chunks,
+        query=query,
+        query_param=query_param,
+        global_config=text_chunks_db.global_config,
+        chunk_tracking=search_result["chunk_tracking"],
+    )
+
+    return context
+
+
 async def _get_node_data(
     query: str,
     knowledge_graph_inst: BaseGraphStorage,