feat: Implement map-reduce summarization to handle merging of large numbers of descriptions

yangdx 2025-08-25 21:03:16 +08:00
parent 0b1b264a5d
commit 882d6857d8


@@ -115,47 +115,152 @@ def chunking_by_token_size(
async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
description_list: list[str],
force_llm_summary_on_merge: int,
seperator: str,
global_config: dict,
llm_response_cache: BaseKVStorage | None = None,
) -> str:
"""Handle entity relation summary
For each entity or relation, input is the combined description of already existing description and new description.
If too long, use LLM to summarize.
"""Handle entity relation description summary using map-reduce approach.
This function summarizes a list of descriptions using a map-reduce strategy:
1. If total tokens <= summary_max_tokens, summarize directly
2. Otherwise, split descriptions into chunks that fit within token limits
3. Summarize each chunk, then recursively process the summaries
4. Continue until the final summary fits within the token limit or the number of descriptions is less than force_llm_summary_on_merge
Args:
entity_or_relation_name: Name of the entity or relation being summarized
description_list: List of description strings to summarize
global_config: Global configuration containing tokenizer and limits
llm_response_cache: Optional cache for LLM responses
Returns:
Final summarized description string
"""
# Handle empty input
if not description_list:
return ""
# If only one description, return it directly (no need for LLM call)
if len(description_list) == 1:
return description_list[0]
# Get configuration
tokenizer: Tokenizer = global_config["tokenizer"]
summary_max_tokens = global_config["summary_max_tokens"]
current_list = description_list[:] # Copy the list to avoid modifying original
# Iterative map-reduce process
while True:
# Calculate total tokens in current list
total_tokens = sum(len(tokenizer.encode(desc)) for desc in current_list)
# If total length is within limits, perform final summarization
if (
total_tokens <= summary_max_tokens
or len(current_list) < force_llm_summary_on_merge
):
if len(current_list) < force_llm_summary_on_merge:
# Already the final result
final_description = seperator.join(current_list)
return final_description if final_description else ""
else:
# Final summarization of remaining descriptions
return await _summarize_descriptions(
entity_or_relation_name,
current_list,
global_config,
llm_response_cache,
)
# Need to split into chunks - Map phase
chunks = []
current_chunk = []
current_tokens = 0
for desc in current_list:
desc_tokens = len(tokenizer.encode(desc))
# If adding current description would exceed limit, finalize current chunk
if current_tokens + desc_tokens > summary_max_tokens and current_chunk:
chunks.append(current_chunk)
current_chunk = [desc]
current_tokens = desc_tokens
else:
current_chunk.append(desc)
current_tokens += desc_tokens
# Add the last chunk if it exists
if current_chunk:
chunks.append(current_chunk)
logger.info(
f"Summarizing {entity_or_relation_name}: split {len(current_list)} descriptions into {len(chunks)} groups"
)
# Reduce phase: summarize each chunk
new_summaries = []
for chunk in chunks:
if len(chunk) == 1:
# Optimization: single description chunks don't need LLM summarization
new_summaries.append(chunk[0])
else:
# Multiple descriptions need LLM summarization
summary = await _summarize_descriptions(
entity_or_relation_name, chunk, global_config, llm_response_cache
)
new_summaries.append(summary)
# Update current list with new summaries for next iteration
current_list = new_summaries
async def _summarize_descriptions(
entity_or_relation_name: str,
description_list: list[str],
global_config: dict,
llm_response_cache: BaseKVStorage | None = None,
) -> str:
"""Helper function to summarize a list of descriptions using LLM.
Args:
entity_or_relation_name: Name of the entity or relation being summarized
description_list: List of description strings to summarize
global_config: Global configuration containing LLM function and settings
llm_response_cache: Optional cache for LLM responses
Returns:
Summarized description string
"""
use_llm_func: callable = global_config["llm_model_func"]
# Apply higher priority (8) to entity/relation summary tasks
use_llm_func = partial(use_llm_func, _priority=8)
tokenizer: Tokenizer = global_config["tokenizer"]
llm_max_tokens = global_config["summary_max_tokens"]
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
tokens = tokenizer.encode(description)
### summarize is not determined here anymore (It's determined by num_fragment now)
# if len(tokens) < summary_max_tokens: # No need for summary
# return description
prompt_template = PROMPTS["summarize_entity_descriptions"]
use_description = tokenizer.decode(tokens[:llm_max_tokens])
# Prepare context for the prompt
context_base = dict(
entity_name=entity_or_relation_name,
description_list=use_description.split(GRAPH_FIELD_SEP),
description_list="\n".join(description_list),
language=language,
)
use_prompt = prompt_template.format(**context_base)
logger.debug(f"Trigger summary: {entity_or_relation_name}")
logger.debug(
f"Summarizing {len(description_list)} descriptions for: {entity_or_relation_name}"
)
# Use LLM function with cache (higher priority for summary generation)
summary = await use_llm_func_with_cache(
use_prompt,
use_llm_func,
llm_response_cache=llm_response_cache,
# max_tokens=summary_max_tokens,
cache_type="extract",
)
return summary
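
For reference, a minimal standalone sketch of the map-reduce loop implemented above, with a whitespace word count standing in for tokenizer.encode and a stub in place of the LLM summarization call; the names and threshold values below are illustrative only, not part of this commit:

def token_count(text: str) -> int:
    # Stand-in for tokenizer.encode; counts whitespace-separated words
    return len(text.split())

def stub_summarize(descs: list[str]) -> str:
    # Stand-in for the async LLM call in _summarize_descriptions
    return f"<summary of {len(descs)} descriptions>"

def map_reduce_summary(descs: list[str], max_tokens: int = 10,
                       force_llm_on_merge: int = 3, sep: str = " | ") -> str:
    current = descs[:]
    while True:
        total = sum(token_count(d) for d in current)
        if total <= max_tokens or len(current) < force_llm_on_merge:
            if len(current) < force_llm_on_merge:
                return sep.join(current)  # few enough fragments: plain join
            return stub_summarize(current)  # final reduce step
        # Map phase: pack descriptions into chunks within the token budget
        chunks, chunk, used = [], [], 0
        for d in current:
            t = token_count(d)
            if used + t > max_tokens and chunk:
                chunks.append(chunk)
                chunk, used = [d], t
            else:
                chunk.append(d)
                used += t
        if chunk:
            chunks.append(chunk)
        # Reduce phase: single-item chunks skip the LLM entirely
        current = [c[0] if len(c) == 1 else stub_summarize(c) for c in chunks]

print(map_reduce_summary(["a b c", "d e f", "g h i", "j k l m"]))
# <summary of 3 descriptions> | j k l m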
@@ -413,7 +518,7 @@ async def _rebuild_knowledge_from_chunks(
)
rebuilt_entities_count += 1
status_message = (
f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks"
f"Entity `{entity_name}` rebuilt from {len(chunk_ids)} chunks"
)
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
@@ -453,7 +558,7 @@ async def _rebuild_knowledge_from_chunks(
global_config=global_config,
)
rebuilt_relationships_count += 1
status_message = f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks"
status_message = f"Relationship `{src}->{tgt}` rebuilt from {len(chunk_ids)} chunks"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
@@ -736,21 +841,20 @@ async def _rebuild_single_entity(
edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP)
file_paths.update(edge_file_paths)
# Generate description from relationships or fallback to current
if relationship_descriptions:
combined_description = GRAPH_FIELD_SEP.join(relationship_descriptions)
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1
# deduplicate descriptions
description_list = list(dict.fromkeys(relationship_descriptions))
if num_fragment >= force_llm_summary_on_merge:
final_description = await _handle_entity_relation_summary(
entity_name,
combined_description,
global_config,
llm_response_cache=llm_response_cache,
)
else:
final_description = combined_description
# Generate description from relationships or fallback to current
if description_list:
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
final_description = await _handle_entity_relation_summary(
entity_name,
description_list,
force_llm_summary_on_merge,
GRAPH_FIELD_SEP,
global_config,
llm_response_cache=llm_response_cache,
)
else:
final_description = current_entity.get("description", "")
@@ -772,16 +876,9 @@ async def _rebuild_single_entity(
file_paths.add(entity_data["file_path"])
# Remove duplicates while preserving order
descriptions = list(dict.fromkeys(descriptions))
description_list = list(dict.fromkeys(descriptions))
entity_types = list(dict.fromkeys(entity_types))
# Combine all descriptions
combined_description = (
GRAPH_FIELD_SEP.join(descriptions)
if descriptions
else current_entity.get("description", "")
)
# Get most common entity type
entity_type = (
max(set(entity_types), key=entity_types.count)
@@ -791,17 +888,17 @@ async def _rebuild_single_entity(
# Generate final description and update storage
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1
if num_fragment >= force_llm_summary_on_merge:
if description_list:
final_description = await _handle_entity_relation_summary(
entity_name,
combined_description,
description_list,
force_llm_summary_on_merge,
GRAPH_FIELD_SEP,
global_config,
llm_response_cache=llm_response_cache,
)
else:
final_description = combined_description
final_description = current_entity.get("description", "")
await _update_entity_storage(final_description, entity_type, file_paths)
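
A side note on the dedup idiom used throughout this change: dict.fromkeys removes duplicates while preserving first-seen order, which set() would not. A quick illustration with made-up values:

# dict.fromkeys keeps first occurrences in insertion order, unlike set()
descriptions = ["A port city.", "Capital of X.", "A port city."]
description_list = list(dict.fromkeys(descriptions))
print(description_list)  # ['A port city.', 'Capital of X.']

Order matters here because it carries merge history: already-stored fragments stay ahead of newly extracted ones when the list is later joined or summarized.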
@@ -859,45 +956,38 @@ async def _rebuild_single_relationship(
file_paths.add(rel_data["file_path"])
# Remove duplicates while preserving order
descriptions = list(dict.fromkeys(descriptions))
description_list = list(dict.fromkeys(descriptions))
keywords = list(dict.fromkeys(keywords))
# Combine descriptions and keywords (fallback to keep current unchanged)
combined_description = (
GRAPH_FIELD_SEP.join(descriptions)
if descriptions
else current_relationship.get("description", "")
)
combined_keywords = (
", ".join(set(keywords))
if keywords
else current_relationship.get("keywords", "")
)
# weight = (
# sum(weights) / len(weights)
# if weights
# else current_relationship.get("weight", 1.0)
# )
weight = sum(weights) if weights else current_relationship.get("weight", 1.0)
# Use summary if description has too many fragments
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1
if num_fragment >= force_llm_summary_on_merge:
if description_list:
final_description = await _handle_entity_relation_summary(
f"{src}-{tgt}",
combined_description,
description_list,
force_llm_summary_on_merge,
GRAPH_FIELD_SEP,
global_config,
llm_response_cache=llm_response_cache,
)
else:
final_description = combined_description
# fallback to keep current description (unchanged)
final_description = current_relationship.get("description", "")
# Update relationship in graph storage
updated_relationship_data = {
**current_relationship,
"description": final_description,
"description": final_description
if final_description
else current_relationship.get("description", ""),
"keywords": combined_keywords,
"weight": weight,
"source_id": GRAPH_FIELD_SEP.join(chunk_ids),
@@ -971,21 +1061,16 @@ async def _merge_nodes_then_upsert(
reverse=True,
)[0][0]
description = GRAPH_FIELD_SEP.join(
already_description
+ list(
dict.fromkeys(
[dp["description"] for dp in nodes_data if dp.get("description")]
)
)
description_list = already_description + list(
dict.fromkeys([dp["description"] for dp in nodes_data if dp.get("description")])
)
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = description.count(GRAPH_FIELD_SEP) + 1
num_fragment = len(description_list)
already_fragment = already_description.count(GRAPH_FIELD_SEP) + 1
if num_fragment > 1:
if num_fragment > 0:
if num_fragment >= force_llm_summary_on_merge:
status_message = f"LLM merge N: {entity_name} | {already_fragment}+{num_fragment-already_fragment}"
status_message = f"LLM merging `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
@@ -993,17 +1078,23 @@ async def _merge_nodes_then_upsert(
pipeline_status["history_messages"].append(status_message)
description = await _handle_entity_relation_summary(
entity_name,
description,
description_list,
force_llm_summary_on_merge,
GRAPH_FIELD_SEP,
global_config,
llm_response_cache,
)
else:
status_message = f"Merge N: {entity_name} | {already_fragment}+{num_fragment-already_fragment}"
status_message = f"Merging `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
description = GRAPH_FIELD_SEP.join(description_list)
else:
logger.error(f"Entity {entity_name} has no description")
description = "(no description)"
source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in nodes_data] + already_source_ids)
@@ -1084,21 +1175,16 @@ async def _merge_edges_then_upsert(
# Process edges_data with None checks
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join(
already_description
+ list(
dict.fromkeys(
[dp["description"] for dp in edges_data if dp.get("description")]
)
)
description_list = already_description + list(
dict.fromkeys([dp["description"] for dp in edges_data if dp.get("description")])
)
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = description.count(GRAPH_FIELD_SEP) + 1
num_fragment = len(description_list)
already_fragment = already_description.count(GRAPH_FIELD_SEP) + 1
if num_fragment > 1:
if num_fragment > 0:
if num_fragment >= force_llm_summary_on_merge:
status_message = f"LLM merge E: {src_id} - {tgt_id} | {already_fragment}+{num_fragment-already_fragment}"
status_message = f"LLM merging `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
@@ -1106,17 +1192,23 @@ async def _merge_edges_then_upsert(
pipeline_status["history_messages"].append(status_message)
description = await _handle_entity_relation_summary(
f"({src_id}, {tgt_id})",
description,
description_list,
force_llm_summary_on_merge,
GRAPH_FIELD_SEP,
global_config,
llm_response_cache,
)
else:
status_message = f"Merge E: {src_id} - {tgt_id} | {already_fragment}+{num_fragment-already_fragment}"
status_message = f"Merging `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message)
description = GRAPH_FIELD_SEP.join(description_list)
else:
logger.error(f"Edge {src_id} - {tgt_id} has no description")
description = "(no description)"
# Split all existing and new keywords into individual terms, then combine and deduplicate
all_keywords = set()