Refac: Optimize document deletion performance

- Add chunks_list to doc_status
- Add llm_cache_list to text_chunks
- Implement for both JsonKV and Redis storage types
yangdx 2025-07-03 04:18:25 +08:00
parent d0f04383cc
commit e56734cb8b
7 changed files with 208 additions and 71 deletions
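
The gist of the optimization: instead of scanning every entry in text_chunks to find a document's chunks (and scanning the whole LLM cache to find their extraction results), the document's status record now carries a chunks_list of chunk IDs and each chunk carries an llm_cache_list of cache keys, so deletion and rebuild become direct keyed lookups. A minimal sketch of the before/after chunk lookup, using plain dicts as a stand-in for the KV storages (the doc_status/text_chunks fixtures below are hypothetical, not LightRAG APIs):

# Hypothetical in-memory stand-ins for the doc_status and text_chunks KV stores.
doc_status = {
    "doc-1": {"status": "processed", "chunks_list": ["chunk-a", "chunk-b"]},
}
text_chunks = {
    "chunk-a": {"full_doc_id": "doc-1", "llm_cache_list": ["default:extract:hash1"]},
    "chunk-b": {"full_doc_id": "doc-1", "llm_cache_list": ["default:extract:hash2"]},
    "chunk-c": {"full_doc_id": "doc-2", "llm_cache_list": []},
}

def chunks_before(doc_id: str) -> set[str]:
    # Old path: full scan of text_chunks, O(total chunks) per deletion.
    return {cid for cid, c in text_chunks.items() if c.get("full_doc_id") == doc_id}

def chunks_after(doc_id: str) -> set[str]:
    # New path: read the chunk IDs stored on the document status, O(chunks of this doc).
    return set(doc_status[doc_id].get("chunks_list", []))

assert chunks_before("doc-1") == chunks_after("doc-1") == {"chunk-a", "chunk-b"}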

View file

@@ -634,6 +634,8 @@ class DocProcessingStatus:
     """ISO format timestamp when document was last updated"""
     chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
+    chunks_list: list[str] | None = field(default_factory=list)
+    """List of chunk IDs associated with this document, used for deletion"""
     error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)

View file

@@ -118,6 +118,10 @@ class JsonDocStatusStorage(DocStatusStorage):
             return
         logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
+            # Ensure chunks_list field exists for new documents
+            for doc_id, doc_data in data.items():
+                if "chunks_list" not in doc_data:
+                    doc_data["chunks_list"] = []
             self._data.update(data)
             await set_all_update_flags(self.namespace)

View file

@@ -109,6 +109,11 @@ class JsonKVStorage(BaseKVStorage):
             return
         logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
+            # For text_chunks namespace, ensure llm_cache_list field exists
+            if "text_chunks" in self.namespace:
+                for chunk_id, chunk_data in data.items():
+                    if "llm_cache_list" not in chunk_data:
+                        chunk_data["llm_cache_list"] = []
             self._data.update(data)
             await set_all_update_flags(self.namespace)

View file

@@ -202,6 +202,12 @@ class RedisKVStorage(BaseKVStorage):
             return
         async with self._get_redis_connection() as redis:
             try:
+                # For text_chunks namespace, ensure llm_cache_list field exists
+                if "text_chunks" in self.namespace:
+                    for chunk_id, chunk_data in data.items():
+                        if "llm_cache_list" not in chunk_data:
+                            chunk_data["llm_cache_list"] = []
+
                 pipe = redis.pipeline()
                 for k, v in data.items():
                     pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -601,6 +607,11 @@ class RedisDocStatusStorage(DocStatusStorage):
         logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._get_redis_connection() as redis:
             try:
+                # Ensure chunks_list field exists for new documents
+                for doc_id, doc_data in data.items():
+                    if "chunks_list" not in doc_data:
+                        doc_data["chunks_list"] = []
+
                 pipe = redis.pipeline()
                 for k, v in data.items():
                     pipe.set(f"{self.namespace}:{k}", json.dumps(v))

View file

@@ -349,6 +349,7 @@ class LightRAG:
         # Fix global_config now
         global_config = asdict(self)
         _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -952,6 +953,7 @@ class LightRAG:
                         **dp,
                         "full_doc_id": doc_id,
                         "file_path": file_path,  # Add file path to each chunk
+                        "llm_cache_list": [],  # Initialize empty LLM cache list for each chunk
                     }
                     for dp in self.chunking_func(
                         self.tokenizer,
@@ -963,14 +965,17 @@ class LightRAG:
                     )
                 }

-                # Process document (text chunks and full docs) in parallel
-                # Create tasks with references for potential cancellation
+                # Process document in two stages
+                # Stage 1: Process text chunks and docs (parallel execution)
                 doc_status_task = asyncio.create_task(
                     self.doc_status.upsert(
                         {
                             doc_id: {
                                 "status": DocStatus.PROCESSING,
                                 "chunks_count": len(chunks),
+                                "chunks_list": list(
+                                    chunks.keys()
+                                ),  # Save chunks list
                                 "content": status_doc.content,
                                 "content_summary": status_doc.content_summary,
                                 "content_length": status_doc.content_length,
@@ -986,11 +991,6 @@ class LightRAG:
                 chunks_vdb_task = asyncio.create_task(
                     self.chunks_vdb.upsert(chunks)
                 )
-                entity_relation_task = asyncio.create_task(
-                    self._process_entity_relation_graph(
-                        chunks, pipeline_status, pipeline_status_lock
-                    )
-                )
                 full_docs_task = asyncio.create_task(
                     self.full_docs.upsert(
                         {doc_id: {"content": status_doc.content}}
@@ -999,14 +999,26 @@ class LightRAG:
                 text_chunks_task = asyncio.create_task(
                     self.text_chunks.upsert(chunks)
                 )
-                tasks = [
+                # First stage tasks (parallel execution)
+                first_stage_tasks = [
                     doc_status_task,
                     chunks_vdb_task,
-                    entity_relation_task,
                     full_docs_task,
                     text_chunks_task,
                 ]
-                await asyncio.gather(*tasks)
+                entity_relation_task = None
+
+                # Execute first stage tasks
+                await asyncio.gather(*first_stage_tasks)
+
+                # Stage 2: Process entity relation graph (after text_chunks are saved)
+                entity_relation_task = asyncio.create_task(
+                    self._process_entity_relation_graph(
+                        chunks, pipeline_status, pipeline_status_lock
+                    )
+                )
+                await entity_relation_task

                 file_extraction_stage_ok = True
             except Exception as e:
@@ -1021,14 +1033,14 @@ class LightRAG:
                 )
                 pipeline_status["history_messages"].append(error_msg)

-                # Cancel other tasks as they are no longer meaningful
-                for task in [
-                    chunks_vdb_task,
-                    entity_relation_task,
-                    full_docs_task,
-                    text_chunks_task,
-                ]:
-                    if not task.done():
+                # Cancel tasks that are not yet completed
+                all_tasks = first_stage_tasks + (
+                    [entity_relation_task]
+                    if entity_relation_task
+                    else []
+                )
+                for task in all_tasks:
+                    if task and not task.done():
                         task.cancel()

                 # Persistent llm cache
@@ -1078,6 +1090,9 @@ class LightRAG:
                             doc_id: {
                                 "status": DocStatus.PROCESSED,
                                 "chunks_count": len(chunks),
+                                "chunks_list": list(
+                                    chunks.keys()
+                                ),  # Keep chunks_list
                                 "content": status_doc.content,
                                 "content_summary": status_doc.content_summary,
                                 "content_length": status_doc.content_length,
@@ -1196,6 +1211,7 @@ class LightRAG:
                 pipeline_status=pipeline_status,
                 pipeline_status_lock=pipeline_status_lock,
                 llm_response_cache=self.llm_response_cache,
+                text_chunks_storage=self.text_chunks,
             )
             return chunk_results
         except Exception as e:
@@ -1726,28 +1742,10 @@ class LightRAG:
                     file_path="",
                 )

-            # 2. Get all chunks related to this document
-            try:
-                all_chunks = await self.text_chunks.get_all()
-                related_chunks = {
-                    chunk_id: chunk_data
-                    for chunk_id, chunk_data in all_chunks.items()
-                    if isinstance(chunk_data, dict)
-                    and chunk_data.get("full_doc_id") == doc_id
-                }
-
-                # Update pipeline status after getting chunks count
-                async with pipeline_status_lock:
-                    log_message = f"Retrieved {len(related_chunks)} of {len(all_chunks)} related chunks"
-                    logger.info(log_message)
-                    pipeline_status["latest_message"] = log_message
-                    pipeline_status["history_messages"].append(log_message)
-            except Exception as e:
-                logger.error(f"Failed to retrieve chunks for document {doc_id}: {e}")
-                raise Exception(f"Failed to retrieve document chunks: {e}") from e
-
-            if not related_chunks:
+            # 2. Get chunk IDs from document status
+            chunk_ids = set(doc_status_data.get("chunks_list", []))
+
+            if not chunk_ids:
                 logger.warning(f"No chunks found for document {doc_id}")

             # Mark that deletion operations have started
             deletion_operations_started = True
@@ -1778,7 +1776,6 @@ class LightRAG:
                     file_path=file_path,
                 )

-            chunk_ids = set(related_chunks.keys())

             # Mark that deletion operations have started
             deletion_operations_started = True
@@ -1943,7 +1940,7 @@ class LightRAG:
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entities_vdb=self.entities_vdb,
                 relationships_vdb=self.relationships_vdb,
-                text_chunks=self.text_chunks,
+                text_chunks_storage=self.text_chunks,
                 llm_response_cache=self.llm_response_cache,
                 global_config=asdict(self),
                 pipeline_status=pipeline_status,
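
The reordering above follows from the new bookkeeping: entity/relation extraction now writes llm_cache_list entries back into text_chunks, so the chunk records must be persisted before extraction starts rather than in parallel with it. A schematic of the two stages (the coroutine names here are illustrative placeholders, not the real LightRAG tasks):

import asyncio

async def upsert_doc_status(): ...
async def upsert_chunks_vdb(): ...
async def upsert_full_docs(): ...
async def upsert_text_chunks(): ...
async def extract_entities_and_relations(): ...

async def process_document() -> None:
    # Stage 1: persist status, vectors, full doc, and text chunks in parallel.
    await asyncio.gather(
        upsert_doc_status(),
        upsert_chunks_vdb(),
        upsert_full_docs(),
        upsert_text_chunks(),
    )
    # Stage 2: only now run graph extraction, which updates the stored chunks'
    # llm_cache_list and therefore needs text_chunks to be saved already.
    await extract_entities_and_relations()

asyncio.run(process_document())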

View file

@@ -25,6 +25,7 @@ from .utils import (
     CacheData,
     get_conversation_turns,
     use_llm_func_with_cache,
+    update_chunk_cache_list,
 )
 from .base import (
     BaseGraphStorage,
@@ -103,8 +104,6 @@ async def _handle_entity_relation_summary(
     entity_or_relation_name: str,
     description: str,
     global_config: dict,
-    pipeline_status: dict = None,
-    pipeline_status_lock=None,
     llm_response_cache: BaseKVStorage | None = None,
 ) -> str:
     """Handle entity relation summary
@@ -247,7 +246,7 @@ async def _rebuild_knowledge_from_chunks(
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
-    text_chunks: BaseKVStorage,
+    text_chunks_storage: BaseKVStorage,
     llm_response_cache: BaseKVStorage,
     global_config: dict[str, str],
     pipeline_status: dict | None = None,
@@ -261,6 +260,7 @@ async def _rebuild_knowledge_from_chunks(
     Args:
         entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids
         relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids
+        text_chunks_data: Pre-loaded chunk data dict {chunk_id: chunk_data}
     """
     if not entities_to_rebuild and not relationships_to_rebuild:
         return
@@ -273,6 +273,8 @@ async def _rebuild_knowledge_from_chunks(
         all_referenced_chunk_ids.update(chunk_ids)
     for chunk_ids in relationships_to_rebuild.values():
         all_referenced_chunk_ids.update(chunk_ids)
+    # sort all_referenced_chunk_ids to get a stable order in merge stage
+    all_referenced_chunk_ids = sorted(all_referenced_chunk_ids)

     status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
     logger.info(status_message)
@@ -281,9 +283,11 @@ async def _rebuild_knowledge_from_chunks(
             pipeline_status["latest_message"] = status_message
             pipeline_status["history_messages"].append(status_message)

-    # Get cached extraction results for these chunks
+    # Get cached extraction results for these chunks using storage
     cached_results = await _get_cached_extraction_results(
-        llm_response_cache, all_referenced_chunk_ids
+        llm_response_cache,
+        all_referenced_chunk_ids,
+        text_chunks_storage=text_chunks_storage,
     )

     if not cached_results:
@@ -299,15 +303,25 @@ async def _rebuild_knowledge_from_chunks(
     chunk_entities = {}  # chunk_id -> {entity_name: [entity_data]}
     chunk_relationships = {}  # chunk_id -> {(src, tgt): [relationship_data]}

-    for chunk_id, extraction_result in cached_results.items():
+    for chunk_id, extraction_results in cached_results.items():
         try:
-            entities, relationships = await _parse_extraction_result(
-                text_chunks=text_chunks,
-                extraction_result=extraction_result,
-                chunk_id=chunk_id,
-            )
-            chunk_entities[chunk_id] = entities
-            chunk_relationships[chunk_id] = relationships
+            # Handle multiple extraction results per chunk
+            chunk_entities[chunk_id] = defaultdict(list)
+            chunk_relationships[chunk_id] = defaultdict(list)
+
+            for extraction_result in extraction_results:
+                entities, relationships = await _parse_extraction_result(
+                    text_chunks_storage=text_chunks_storage,
+                    extraction_result=extraction_result,
+                    chunk_id=chunk_id,
+                )
+
+                # Merge entities and relationships from this extraction result
+                for entity_name, entity_list in entities.items():
+                    chunk_entities[chunk_id][entity_name].extend(entity_list)
+                for rel_key, rel_list in relationships.items():
+                    chunk_relationships[chunk_id][rel_key].extend(rel_list)
         except Exception as e:
             status_message = (
                 f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
@@ -387,43 +401,76 @@ async def _rebuild_knowledge_from_chunks(

 async def _get_cached_extraction_results(
-    llm_response_cache: BaseKVStorage, chunk_ids: set[str]
-) -> dict[str, str]:
+    llm_response_cache: BaseKVStorage,
+    chunk_ids: set[str],
+    text_chunks_storage: BaseKVStorage,
+) -> dict[str, list[str]]:
     """Get cached extraction results for specific chunk IDs

     Args:
+        llm_response_cache: LLM response cache storage
         chunk_ids: Set of chunk IDs to get cached results for
+        text_chunks_data: Pre-loaded chunk data (optional, for performance)
+        text_chunks_storage: Text chunks storage (fallback if text_chunks_data is None)

     Returns:
-        Dict mapping chunk_id -> extraction_result_text
+        Dict mapping chunk_id -> list of extraction_result_text
     """
     cached_results = {}

-    # Get all cached data (flattened cache structure)
-    all_cache = await llm_response_cache.get_all()
-
-    for cache_key, cache_entry in all_cache.items():
+    # Collect all LLM cache IDs from chunks
+    all_cache_ids = set()
+
+    # Read from storage
+    chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids))
+    for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list):
+        if chunk_data and isinstance(chunk_data, dict):
+            llm_cache_list = chunk_data.get("llm_cache_list", [])
+            if llm_cache_list:
+                all_cache_ids.update(llm_cache_list)
+        else:
+            logger.warning(
+                f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}"
+            )
+
+    if not all_cache_ids:
+        logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs")
+        return cached_results
+
+    # Batch get LLM cache entries
+    cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids))
+
+    # Process cache entries and group by chunk_id
+    valid_entries = 0
+    for cache_id, cache_entry in zip(all_cache_ids, cache_data_list):
         if (
-            isinstance(cache_entry, dict)
+            cache_entry is not None
+            and isinstance(cache_entry, dict)
             and cache_entry.get("cache_type") == "extract"
            and cache_entry.get("chunk_id") in chunk_ids
         ):
             chunk_id = cache_entry["chunk_id"]
             extraction_result = cache_entry["return"]
-            cached_results[chunk_id] = extraction_result
-
-    logger.debug(
-        f"Found {len(cached_results)} cached extraction results for {len(chunk_ids)} chunk IDs"
+            valid_entries += 1
+
+            # Support multiple LLM caches per chunk
+            if chunk_id not in cached_results:
+                cached_results[chunk_id] = []
+            cached_results[chunk_id].append(extraction_result)
+
+    logger.info(
+        f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
     )
     return cached_results


 async def _parse_extraction_result(
-    text_chunks: BaseKVStorage, extraction_result: str, chunk_id: str
+    text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str
 ) -> tuple[dict, dict]:
     """Parse cached extraction result using the same logic as extract_entities

     Args:
+        text_chunks_storage: Text chunks storage to get chunk data
         extraction_result: The cached LLM extraction result
         chunk_id: The chunk ID for source tracking
@@ -431,8 +478,8 @@ async def _parse_extraction_result(
         Tuple of (entities_dict, relationships_dict)
     """

-    # Get chunk data for file_path
-    chunk_data = await text_chunks.get_by_id(chunk_id)
+    # Get chunk data for file_path from storage
+    chunk_data = await text_chunks_storage.get_by_id(chunk_id)
     file_path = (
         chunk_data.get("file_path", "unknown_source")
         if chunk_data
@@ -805,8 +852,6 @@ async def _merge_nodes_then_upsert(
                 entity_name,
                 description,
                 global_config,
-                pipeline_status,
-                pipeline_status_lock,
                 llm_response_cache,
             )
         else:
@@ -969,8 +1014,6 @@ async def _merge_edges_then_upsert(
                 f"({src_id}, {tgt_id})",
                 description,
                 global_config,
-                pipeline_status,
-                pipeline_status_lock,
                 llm_response_cache,
             )
         else:
@@ -1146,6 +1189,7 @@ async def extract_entities(
     pipeline_status: dict = None,
     pipeline_status_lock=None,
     llm_response_cache: BaseKVStorage | None = None,
+    text_chunks_storage: BaseKVStorage | None = None,
 ) -> list:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -1252,6 +1296,9 @@ async def extract_entities(
         # Get file path from chunk data or use default
         file_path = chunk_dp.get("file_path", "unknown_source")

+        # Create cache keys collector for batch processing
+        cache_keys_collector = []
+
         # Get initial extraction
         hint_prompt = entity_extract_prompt.format(
             **{**context_base, "input_text": content}
@@ -1263,7 +1310,10 @@ async def extract_entities(
             llm_response_cache=llm_response_cache,
             cache_type="extract",
             chunk_id=chunk_key,
+            cache_keys_collector=cache_keys_collector,
         )
+
+        # Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache)
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

         # Process initial extraction with file path
@@ -1280,6 +1330,7 @@ async def extract_entities(
                 history_messages=history,
                 cache_type="extract",
                 chunk_id=chunk_key,
+                cache_keys_collector=cache_keys_collector,
             )
             history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
@@ -1310,11 +1361,21 @@ async def extract_entities(
                 llm_response_cache=llm_response_cache,
                 history_messages=history,
                 cache_type="extract",
+                cache_keys_collector=cache_keys_collector,
             )
             if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
             if if_loop_result != "yes":
                 break

+        # Batch update chunk's llm_cache_list with all collected cache keys
+        if cache_keys_collector and text_chunks_storage:
+            await update_chunk_cache_list(
+                chunk_key,
+                text_chunks_storage,
+                cache_keys_collector,
+                "entity_extraction",
+            )
+
         processed_chunks += 1
         entities_count = len(maybe_nodes)
         relations_count = len(maybe_edges)
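
Taken together, the extraction-side bookkeeping and the rebuild-side lookup form one round trip: extract_entities collects the cache key of every LLM call it makes for a chunk and writes that list onto the chunk record, and _get_cached_extraction_results later reads llm_cache_list and batch-fetches exactly those cache entries instead of calling get_all() on the cache. A rough sketch of that round trip against an in-memory stand-in store (MemoryKV is hypothetical, not LightRAG's BaseKVStorage):

import asyncio

class MemoryKV:
    """Tiny dict-backed stand-in exposing only the calls used below."""
    def __init__(self) -> None:
        self._data: dict[str, dict] = {}
    async def get_by_id(self, key: str) -> dict | None:
        return self._data.get(key)
    async def get_by_ids(self, keys: list[str]) -> list[dict | None]:
        return [self._data.get(k) for k in keys]
    async def upsert(self, data: dict[str, dict]) -> None:
        self._data.update(data)

async def demo() -> None:
    text_chunks, llm_cache = MemoryKV(), MemoryKV()
    await text_chunks.upsert({"chunk-a": {"content": "...", "llm_cache_list": []}})

    # Extraction side: record the cache key of each LLM call on the chunk.
    cache_key = "default:extract:abc123"
    await llm_cache.upsert(
        {cache_key: {"cache_type": "extract", "chunk_id": "chunk-a", "return": "(entities...)"}}
    )
    chunk = await text_chunks.get_by_id("chunk-a")
    chunk["llm_cache_list"].append(cache_key)
    await text_chunks.upsert({"chunk-a": chunk})

    # Rebuild side: batch-fetch only the referenced cache entries.
    chunk = await text_chunks.get_by_id("chunk-a")
    entries = await llm_cache.get_by_ids(chunk["llm_cache_list"])
    assert entries[0]["return"] == "(entities...)"

asyncio.run(demo())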

View file

@@ -1423,6 +1423,48 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
     return import_class


+async def update_chunk_cache_list(
+    chunk_id: str,
+    text_chunks_storage: "BaseKVStorage",
+    cache_keys: list[str],
+    cache_scenario: str = "batch_update",
+) -> None:
+    """Update chunk's llm_cache_list with the given cache keys
+
+    Args:
+        chunk_id: Chunk identifier
+        text_chunks_storage: Text chunks storage instance
+        cache_keys: List of cache keys to add to the list
+        cache_scenario: Description of the cache scenario for logging
+    """
+    if not cache_keys:
+        return
+
+    try:
+        chunk_data = await text_chunks_storage.get_by_id(chunk_id)
+        if chunk_data:
+            # Ensure llm_cache_list exists
+            if "llm_cache_list" not in chunk_data:
+                chunk_data["llm_cache_list"] = []
+
+            # Add cache keys to the list if not already present
+            existing_keys = set(chunk_data["llm_cache_list"])
+            new_keys = [key for key in cache_keys if key not in existing_keys]
+
+            if new_keys:
+                chunk_data["llm_cache_list"].extend(new_keys)
+                # Update the chunk in storage
+                await text_chunks_storage.upsert({chunk_id: chunk_data})
+                logger.debug(
+                    f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})"
+                )
+    except Exception as e:
+        logger.warning(
+            f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}"
+        )
+
+
 async def use_llm_func_with_cache(
     input_text: str,
     use_llm_func: callable,
@@ -1431,6 +1473,7 @@ async def use_llm_func_with_cache(
     history_messages: list[dict[str, str]] = None,
     cache_type: str = "extract",
     chunk_id: str | None = None,
+    cache_keys_collector: list = None,
 ) -> str:
     """Call LLM function with cache support
@@ -1445,6 +1488,8 @@ async def use_llm_func_with_cache(
         history_messages: History messages list
         cache_type: Type of cache
         chunk_id: Chunk identifier to store in cache
+        text_chunks_storage: Text chunks storage to update llm_cache_list
+        cache_keys_collector: Optional list to collect cache keys for batch processing

     Returns:
         LLM response text
@@ -1457,6 +1502,9 @@ async def use_llm_func_with_cache(
         _prompt = input_text

         arg_hash = compute_args_hash(_prompt)
+
+        # Generate cache key for this LLM call
+        cache_key = generate_cache_key("default", cache_type, arg_hash)
         cached_return, _1, _2, _3 = await handle_cache(
             llm_response_cache,
             arg_hash,
@@ -1467,6 +1515,11 @@ async def use_llm_func_with_cache(
         if cached_return:
             logger.debug(f"Found cache for {arg_hash}")
             statistic_data["llm_cache"] += 1
+
+            # Add cache key to collector if provided
+            if cache_keys_collector is not None:
+                cache_keys_collector.append(cache_key)
+
             return cached_return

         statistic_data["llm_call"] += 1
@@ -1491,6 +1544,10 @@ async def use_llm_func_with_cache(
             ),
         )

+        # Add cache key to collector if provided
+        if cache_keys_collector is not None:
+            cache_keys_collector.append(cache_key)
+
         return res

     # When cache is disabled, directly call LLM