From e56734cb8b115d610e55d988ceee5d479e34c29a Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 3 Jul 2025 04:18:25 +0800 Subject: [PATCH] Refac: Optimize document deletion performance - Add chunks_list to doc_status - Add llm_cache_list to text_chunks - Implement in the JsonKV and Redis storage backends --- lightrag/base.py | 2 + lightrag/kg/json_doc_status_impl.py | 4 + lightrag/kg/json_kv_impl.py | 5 ++ lightrag/kg/redis_impl.py | 11 +++ lightrag/lightrag.py | 79 +++++++++--------- lightrag/operate.py | 121 +++++++++++++++++++++------- lightrag/utils.py | 57 +++++++++++++ 7 files changed, 208 insertions(+), 71 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 36c3ff59..7820b4da 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -634,6 +634,8 @@ class DocProcessingStatus: """ISO format timestamp when document was last updated""" chunks_count: int | None = None """Number of chunks after splitting, used for processing""" + chunks_list: list[str] | None = field(default_factory=list) + """List of chunk IDs associated with this document, used for deletion""" error: str | None = None """Error message if failed""" metadata: dict[str, Any] = field(default_factory=dict) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index f8387ad8..ab6ab390 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -118,6 +118,10 @@ class JsonDocStatusStorage(DocStatusStorage): return logger.debug(f"Inserting {len(data)} records to {self.namespace}") async with self._storage_lock: + # Ensure chunks_list field exists for new documents + for doc_id, doc_data in data.items(): + if "chunks_list" not in doc_data: + doc_data["chunks_list"] = [] self._data.update(data) await set_all_update_flags(self.namespace) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index d6e2cb70..0d925aaf 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -109,6 +109,11 @@ class JsonKVStorage(BaseKVStorage): return logger.debug(f"Inserting {len(data)} records to {self.namespace}") async with self._storage_lock: + # For text_chunks namespace, ensure llm_cache_list field exists + if "text_chunks" in self.namespace: + for chunk_id, chunk_data in data.items(): + if "llm_cache_list" not in chunk_data: + chunk_data["llm_cache_list"] = [] self._data.update(data) await set_all_update_flags(self.namespace) diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 5be9f0e6..89da0ef4 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -202,6 +202,12 @@ class RedisKVStorage(BaseKVStorage): return async with self._get_redis_connection() as redis: try: + # For text_chunks namespace, ensure llm_cache_list field exists + if "text_chunks" in self.namespace: + for chunk_id, chunk_data in data.items(): + if "llm_cache_list" not in chunk_data: + chunk_data["llm_cache_list"] = [] + pipe = redis.pipeline() for k, v in data.items(): pipe.set(f"{self.namespace}:{k}", json.dumps(v)) @@ -601,6 +607,11 @@ class RedisDocStatusStorage(DocStatusStorage): logger.debug(f"Inserting {len(data)} records to {self.namespace}") async with self._get_redis_connection() as redis: try: + # Ensure chunks_list field exists for new documents + for doc_id, doc_data in data.items(): + if "chunks_list" not in doc_data: + doc_data["chunks_list"] = [] + pipe = redis.pipeline() for k, v in data.items(): pipe.set(f"{self.namespace}:{k}", json.dumps(v)) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 132075d6..ac38dc4c 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -349,6 +349,7 @@ class LightRAG: # Fix global_config now global_config = asdict(self) + _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") @@ -952,6 +953,7 @@ class LightRAG: **dp, "full_doc_id": doc_id, "file_path": file_path, # Add file path to each chunk + "llm_cache_list": [], # Initialize empty LLM cache list for each chunk } for dp in self.chunking_func( self.tokenizer, @@ -963,14 +965,17 @@ class LightRAG: ) } - # Process document (text chunks and full docs) in parallel - # Create tasks with references for potential cancellation + # Process document in two stages + # Stage 1: Process text chunks and docs (parallel execution) doc_status_task = asyncio.create_task( self.doc_status.upsert( { doc_id: { "status": DocStatus.PROCESSING, "chunks_count": len(chunks), + "chunks_list": list( + chunks.keys() + ), # Save chunks list "content": status_doc.content, "content_summary": status_doc.content_summary, "content_length": status_doc.content_length, @@ -986,11 +991,6 @@ class LightRAG: chunks_vdb_task = asyncio.create_task( self.chunks_vdb.upsert(chunks) ) - entity_relation_task = asyncio.create_task( - self._process_entity_relation_graph( - chunks, pipeline_status, pipeline_status_lock - ) - ) full_docs_task = asyncio.create_task( self.full_docs.upsert( {doc_id: {"content": status_doc.content}} @@ -999,14 +999,26 @@ class LightRAG: text_chunks_task = asyncio.create_task( self.text_chunks.upsert(chunks) ) - tasks = [ + + # First stage tasks (parallel execution) + first_stage_tasks = [ doc_status_task, chunks_vdb_task, - entity_relation_task, full_docs_task, text_chunks_task, ] + entity_relation_task = None + + # Execute first stage tasks + await asyncio.gather(*first_stage_tasks) + + # Stage 2: Process entity relation graph (after text_chunks are saved) + entity_relation_task = asyncio.create_task( + self._process_entity_relation_graph( + chunks, pipeline_status, pipeline_status_lock + ) + ) + await entity_relation_task file_extraction_stage_ok = True except Exception as e: @@ -1021,14 +1033,14 @@ class LightRAG: ) pipeline_status["history_messages"].append(error_msg) - # Cancel other tasks as they are no longer meaningful - for task in [ - chunks_vdb_task, - entity_relation_task, - full_docs_task, - text_chunks_task, - ]: - if not task.done(): + # Cancel tasks that are not yet completed + all_tasks = first_stage_tasks + ( + [entity_relation_task] + if entity_relation_task + else [] + ) + for task in all_tasks: + if task and not task.done(): task.cancel() # Persistent llm cache @@ -1078,6 +1090,9 @@ class LightRAG: doc_id: { "status": DocStatus.PROCESSED, "chunks_count": len(chunks), + "chunks_list": list( + chunks.keys() + ), # Keep chunks_list "content": status_doc.content, "content_summary": status_doc.content_summary, "content_length": status_doc.content_length, @@ -1196,6 +1211,7 @@ class LightRAG: pipeline_status=pipeline_status, pipeline_status_lock=pipeline_status_lock, llm_response_cache=self.llm_response_cache, + text_chunks_storage=self.text_chunks, ) return chunk_results except Exception as e: @@ -1726,28 +1742,10 @@ class LightRAG: file_path="", ) - # 2.
Get all chunks related to this document - try: - all_chunks = await self.text_chunks.get_all() - related_chunks = { - chunk_id: chunk_data - for chunk_id, chunk_data in all_chunks.items() - if isinstance(chunk_data, dict) - and chunk_data.get("full_doc_id") == doc_id - } + # 2. Get chunk IDs from document status + chunk_ids = set(doc_status_data.get("chunks_list", [])) - # Update pipeline status after getting chunks count - async with pipeline_status_lock: - log_message = f"Retrieved {len(related_chunks)} of {len(all_chunks)} related chunks" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - - except Exception as e: - logger.error(f"Failed to retrieve chunks for document {doc_id}: {e}") - raise Exception(f"Failed to retrieve document chunks: {e}") from e - - if not related_chunks: + if not chunk_ids: logger.warning(f"No chunks found for document {doc_id}") # Mark that deletion operations have started deletion_operations_started = True @@ -1778,7 +1776,6 @@ class LightRAG: file_path=file_path, ) - chunk_ids = set(related_chunks.keys()) # Mark that deletion operations have started deletion_operations_started = True @@ -1943,7 +1940,7 @@ class LightRAG: knowledge_graph_inst=self.chunk_entity_relation_graph, entities_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, - text_chunks=self.text_chunks, + text_chunks_storage=self.text_chunks, llm_response_cache=self.llm_response_cache, global_config=asdict(self), pipeline_status=pipeline_status, diff --git a/lightrag/operate.py b/lightrag/operate.py index bd70ceed..eacccb98 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -25,6 +25,7 @@ from .utils import ( CacheData, get_conversation_turns, use_llm_func_with_cache, + update_chunk_cache_list, ) from .base import ( BaseGraphStorage, @@ -103,8 +104,6 @@ async def _handle_entity_relation_summary( entity_or_relation_name: str, description: str, global_config: dict, - pipeline_status: dict = None, - pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ) -> str: """Handle entity relation summary @@ -247,7 +246,7 @@ async def _rebuild_knowledge_from_chunks( knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, - text_chunks: BaseKVStorage, + text_chunks_storage: BaseKVStorage, llm_response_cache: BaseKVStorage, global_config: dict[str, str], pipeline_status: dict | None = None, @@ -261,6 +260,7 @@ async def _rebuild_knowledge_from_chunks( Args: entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids + text_chunks_storage: Text chunks storage used to load chunk data and LLM cache references """ if not entities_to_rebuild and not relationships_to_rebuild: return @@ -273,6 +273,8 @@ async def _rebuild_knowledge_from_chunks( all_referenced_chunk_ids.update(chunk_ids) for chunk_ids in relationships_to_rebuild.values(): all_referenced_chunk_ids.update(chunk_ids) + # Sort all_referenced_chunk_ids to get a stable order for the merge stage + all_referenced_chunk_ids = sorted(all_referenced_chunk_ids) status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions" logger.info(status_message) @@ -281,9 +283,11 @@ async def _rebuild_knowledge_from_chunks( pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) - # Get cached extraction results for these chunks + #
Get cached extraction results for these chunks using storage cached_results = await _get_cached_extraction_results( - llm_response_cache, all_referenced_chunk_ids + llm_response_cache, + all_referenced_chunk_ids, + text_chunks_storage=text_chunks_storage, ) if not cached_results: @@ -299,15 +303,25 @@ async def _rebuild_knowledge_from_chunks( chunk_entities = {} # chunk_id -> {entity_name: [entity_data]} chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]} - for chunk_id, extraction_result in cached_results.items(): + for chunk_id, extraction_results in cached_results.items(): try: - entities, relationships = await _parse_extraction_result( - text_chunks=text_chunks, - extraction_result=extraction_result, - chunk_id=chunk_id, - ) - chunk_entities[chunk_id] = entities - chunk_relationships[chunk_id] = relationships + # Handle multiple extraction results per chunk + chunk_entities[chunk_id] = defaultdict(list) + chunk_relationships[chunk_id] = defaultdict(list) + + for extraction_result in extraction_results: + entities, relationships = await _parse_extraction_result( + text_chunks_storage=text_chunks_storage, + extraction_result=extraction_result, + chunk_id=chunk_id, + ) + + # Merge entities and relationships from this extraction result + for entity_name, entity_list in entities.items(): + chunk_entities[chunk_id][entity_name].extend(entity_list) + for rel_key, rel_list in relationships.items(): + chunk_relationships[chunk_id][rel_key].extend(rel_list) + except Exception as e: status_message = ( f"Failed to parse cached extraction result for chunk {chunk_id}: {e}" ) @@ -387,43 +401,76 @@ async def _get_cached_extraction_results( - llm_response_cache: BaseKVStorage, chunk_ids: set[str] -) -> dict[str, str]: + llm_response_cache: BaseKVStorage, + chunk_ids: set[str], + text_chunks_storage: BaseKVStorage, +) -> dict[str, list[str]]: """Get cached extraction results for specific chunk IDs Args: + llm_response_cache: LLM response cache storage chunk_ids: Set of chunk IDs to get cached results for + text_chunks_storage: Text chunks storage holding each chunk's llm_cache_list Returns: - Dict mapping chunk_id -> extraction_result_text + Dict mapping chunk_id -> list of extraction_result_text """ cached_results = {} - # Get all cached data (flattened cache structure) - all_cache = await llm_response_cache.get_all() + # Collect all LLM cache IDs from chunks + all_cache_ids = set() - for cache_key, cache_entry in all_cache.items(): + # Read from storage + chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids)) + for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list): + if chunk_data and isinstance(chunk_data, dict): + llm_cache_list = chunk_data.get("llm_cache_list", []) + if llm_cache_list: + all_cache_ids.update(llm_cache_list) + else: + logger.warning( + f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}" ) + + if not all_cache_ids: + logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs") + return cached_results + + # Batch get LLM cache entries + cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids)) + + # Process cache entries and group by chunk_id + valid_entries = 0 + for cache_id, cache_entry in zip(all_cache_ids, cache_data_list): if ( - isinstance(cache_entry, dict) + cache_entry is not None + and isinstance(cache_entry, dict) and cache_entry.get("cache_type") == "extract" and
cache_entry.get("chunk_id") in chunk_ids ): chunk_id = cache_entry["chunk_id"] extraction_result = cache_entry["return"] - cached_results[chunk_id] = extraction_result + valid_entries += 1 - logger.debug( - f"Found {len(cached_results)} cached extraction results for {len(chunk_ids)} chunk IDs" + # Support multiple LLM caches per chunk + if chunk_id not in cached_results: + cached_results[chunk_id] = [] + cached_results[chunk_id].append(extraction_result) + + logger.info( + f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results" ) return cached_results async def _parse_extraction_result( - text_chunks: BaseKVStorage, extraction_result: str, chunk_id: str + text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str ) -> tuple[dict, dict]: """Parse cached extraction result using the same logic as extract_entities Args: + text_chunks_storage: Text chunks storage to get chunk data extraction_result: The cached LLM extraction result chunk_id: The chunk ID for source tracking Returns: Tuple of (entities_dict, relationships_dict) """ - # Get chunk data for file_path - chunk_data = await text_chunks.get_by_id(chunk_id) + # Get chunk data for file_path from storage + chunk_data = await text_chunks_storage.get_by_id(chunk_id) file_path = ( chunk_data.get("file_path", "unknown_source") if chunk_data @@ -805,8 +852,6 @@ async def _merge_nodes_then_upsert( entity_name, description, global_config, - pipeline_status, - pipeline_status_lock, llm_response_cache, ) else: @@ -969,8 +1014,6 @@ async def _merge_edges_then_upsert( f"({src_id}, {tgt_id})", description, global_config, - pipeline_status, - pipeline_status_lock, llm_response_cache, ) else: @@ -1146,6 +1189,7 @@ async def extract_entities( pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, + text_chunks_storage: BaseKVStorage | None = None, ) -> list: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -1252,6 +1296,9 @@ async def extract_entities( # Get file path from chunk data or use default file_path = chunk_dp.get("file_path", "unknown_source") + # Create cache keys collector for batch processing + cache_keys_collector = [] + # Get initial extraction hint_prompt = entity_extract_prompt.format( **{**context_base, "input_text": content} ) final_result = await use_llm_func_with_cache( hint_prompt, use_llm_func, llm_response_cache=llm_response_cache, cache_type="extract", chunk_id=chunk_key, + cache_keys_collector=cache_keys_collector, ) + + # Cache keys are accumulated in cache_keys_collector and written to the chunk in batch below history = pack_user_ass_to_openai_messages(hint_prompt, final_result) # Process initial extraction with file path @@ -1280,6 +1330,7 @@ async def extract_entities( history_messages=history, cache_type="extract", chunk_id=chunk_key, + cache_keys_collector=cache_keys_collector, ) history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) @@ -1310,11 +1361,21 @@ async def extract_entities( llm_response_cache=llm_response_cache, history_messages=history, cache_type="extract", + cache_keys_collector=cache_keys_collector, ) if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if if_loop_result != "yes": break + # Batch update chunk's llm_cache_list with all collected cache keys + if cache_keys_collector and text_chunks_storage: + await update_chunk_cache_list( + chunk_key, + text_chunks_storage,
cache_keys_collector, + "entity_extraction", + ) + processed_chunks += 1 entities_count = len(maybe_nodes) relations_count = len(maybe_edges) diff --git a/lightrag/utils.py b/lightrag/utils.py index 6c40407b..c6e2def9 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1423,6 +1423,48 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any return import_class +async def update_chunk_cache_list( + chunk_id: str, + text_chunks_storage: "BaseKVStorage", + cache_keys: list[str], + cache_scenario: str = "batch_update", +) -> None: + """Update chunk's llm_cache_list with the given cache keys + + Args: + chunk_id: Chunk identifier + text_chunks_storage: Text chunks storage instance + cache_keys: List of cache keys to add to the list + cache_scenario: Description of the cache scenario for logging + """ + if not cache_keys: + return + + try: + chunk_data = await text_chunks_storage.get_by_id(chunk_id) + if chunk_data: + # Ensure llm_cache_list exists + if "llm_cache_list" not in chunk_data: + chunk_data["llm_cache_list"] = [] + + # Add cache keys to the list if not already present + existing_keys = set(chunk_data["llm_cache_list"]) + new_keys = [key for key in cache_keys if key not in existing_keys] + + if new_keys: + chunk_data["llm_cache_list"].extend(new_keys) + + # Update the chunk in storage + await text_chunks_storage.upsert({chunk_id: chunk_data}) + logger.debug( + f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})" + ) + except Exception as e: + logger.warning( + f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}" + ) + + async def use_llm_func_with_cache( input_text: str, use_llm_func: callable, @@ -1431,6 +1473,7 @@ async def use_llm_func_with_cache( history_messages: list[dict[str, str]] = None, cache_type: str = "extract", chunk_id: str | None = None, + cache_keys_collector: list[str] | None = None, ) -> str: """Call LLM function with cache support @@ -1445,6 +1488,7 @@ async def use_llm_func_with_cache( history_messages: History messages list cache_type: Type of cache chunk_id: Chunk identifier to store in cache + cache_keys_collector: Optional list to collect cache keys for batch processing Returns: LLM response text @@ -1457,6 +1502,9 @@ async def use_llm_func_with_cache( _prompt = input_text arg_hash = compute_args_hash(_prompt) + # Generate cache key for this LLM call + cache_key = generate_cache_key("default", cache_type, arg_hash) + cached_return, _1, _2, _3 = await handle_cache( llm_response_cache, arg_hash, @@ -1467,6 +1515,11 @@ async def use_llm_func_with_cache( if cached_return: logger.debug(f"Found cache for {arg_hash}") statistic_data["llm_cache"] += 1 + + # Add cache key to collector if provided + if cache_keys_collector is not None: + cache_keys_collector.append(cache_key) + return cached_return statistic_data["llm_call"] += 1 @@ -1491,6 +1544,10 @@ async def use_llm_func_with_cache( ), ) + # Add cache key to collector if provided + if cache_keys_collector is not None: + cache_keys_collector.append(cache_key) + return res # When cache is disabled, directly call LLM
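
Reviewer note (not part of the patch): a minimal usage sketch of the new cache-tracking flow. The InMemoryKV class below is a hypothetical stand-in for a BaseKVStorage backend, and the cache key literal is made up for illustration (real keys come from generate_cache_key in lightrag.utils); only update_chunk_cache_list and the llm_cache_list field are taken from this patch.

    import asyncio
    from lightrag.utils import update_chunk_cache_list

    class InMemoryKV:
        """Hypothetical stand-in exposing the subset of BaseKVStorage used here."""
        def __init__(self):
            self._data = {}
        async def get_by_id(self, id):
            return self._data.get(id)
        async def get_by_ids(self, ids):
            return [self._data.get(i) for i in ids]
        async def upsert(self, data):
            self._data.update(data)

    async def demo():
        text_chunks = InMemoryKV()
        # Chunks are now created with an empty llm_cache_list (see the lightrag.py hunk above)
        await text_chunks.upsert({"chunk-1": {"content": "...", "llm_cache_list": []}})

        # During extraction, use_llm_func_with_cache(..., cache_keys_collector=...)
        # appends one key per LLM call; here we simulate the collector's contents.
        cache_keys_collector = ["default:extract:<hypothetical-arg-hash>"]

        # Batch-persist the collected keys onto the chunk, as extract_entities
        # does once gleaning finishes.
        await update_chunk_cache_list(
            "chunk-1", text_chunks, cache_keys_collector, "entity_extraction"
        )

        # Deletion/rebuild can now resolve cache entries by ID via get_by_ids(),
        # replacing the old full scan of llm_response_cache.get_all().
        chunk = await text_chunks.get_by_id("chunk-1")
        print(chunk["llm_cache_list"])

    asyncio.run(demo())

This mirrors the lookup chain in _get_cached_extraction_results (doc_status.chunks_list -> text_chunks.llm_cache_list -> llm_response_cache.get_by_ids), so deletion cost scales with a document's own chunks rather than with the total size of the chunk and cache stores.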