feat: Implement map-reduce summarization to handle large humber of description merging

This commit is contained in:
yangdx 2025-08-25 21:03:16 +08:00
parent 0b1b264a5d
commit 882d6857d8

View file

@ -115,47 +115,152 @@ def chunking_by_token_size(
async def _handle_entity_relation_summary( async def _handle_entity_relation_summary(
entity_or_relation_name: str, entity_or_relation_name: str,
description: str, description_list: list[str],
force_llm_summary_on_merge: int,
seperator: str,
global_config: dict, global_config: dict,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
) -> str: ) -> str:
"""Handle entity relation summary """Handle entity relation description summary using map-reduce approach.
For each entity or relation, input is the combined description of already existing description and new description.
If too long, use LLM to summarize. This function summarizes a list of descriptions using a map-reduce strategy:
1. If total tokens <= summary_max_tokens, summarize directly
2. Otherwise, split descriptions into chunks that fit within token limits
3. Summarize each chunk, then recursively process the summaries
4. Continue until we get a final summary within token limits or num of descriptions is less than force_llm_summary_on_merge
Args:
entity_or_relation_name: Name of the entity or relation being summarized
description_list: List of description strings to summarize
global_config: Global configuration containing tokenizer and limits
llm_response_cache: Optional cache for LLM responses
Returns:
Final summarized description string
"""
# Handle empty input
if not description_list:
return ""
# If only one description, return it directly (no need for LLM call)
if len(description_list) == 1:
return description_list[0]
# Get configuration
tokenizer: Tokenizer = global_config["tokenizer"]
summary_max_tokens = global_config["summary_max_tokens"]
current_list = description_list[:] # Copy the list to avoid modifying original
# Iterative map-reduce process
while True:
# Calculate total tokens in current list
total_tokens = sum(len(tokenizer.encode(desc)) for desc in current_list)
# If total length is within limits, perform final summarization
if (
total_tokens <= summary_max_tokens
or len(current_list) < force_llm_summary_on_merge
):
if len(current_list) < force_llm_summary_on_merge:
# Already the final result
final_description = seperator.join(current_list)
return final_description if final_description else ""
else:
# Final summarization of remaining descriptions
return await _summarize_descriptions(
entity_or_relation_name,
current_list,
global_config,
llm_response_cache,
)
# Need to split into chunks - Map phase
chunks = []
current_chunk = []
current_tokens = 0
for desc in current_list:
desc_tokens = len(tokenizer.encode(desc))
# If adding current description would exceed limit, finalize current chunk
if current_tokens + desc_tokens > summary_max_tokens and current_chunk:
chunks.append(current_chunk)
current_chunk = [desc]
current_tokens = desc_tokens
else:
current_chunk.append(desc)
current_tokens += desc_tokens
# Add the last chunk if it exists
if current_chunk:
chunks.append(current_chunk)
logger.info(
f"Summarizing {entity_or_relation_name}: split {len(current_list)} descriptions into {len(chunks)} groups"
)
# Reduce phase: summarize each chunk
new_summaries = []
for chunk in chunks:
if len(chunk) == 1:
# Optimization: single description chunks don't need LLM summarization
new_summaries.append(chunk[0])
else:
# Multiple descriptions need LLM summarization
summary = await _summarize_descriptions(
entity_or_relation_name, chunk, global_config, llm_response_cache
)
new_summaries.append(summary)
# Update current list with new summaries for next iteration
current_list = new_summaries
async def _summarize_descriptions(
entity_or_relation_name: str,
description_list: list[str],
global_config: dict,
llm_response_cache: BaseKVStorage | None = None,
) -> str:
"""Helper function to summarize a list of descriptions using LLM.
Args:
entity_or_relation_name: Name of the entity or relation being summarized
descriptions: List of description strings to summarize
global_config: Global configuration containing LLM function and settings
llm_response_cache: Optional cache for LLM responses
Returns:
Summarized description string
""" """
use_llm_func: callable = global_config["llm_model_func"] use_llm_func: callable = global_config["llm_model_func"]
# Apply higher priority (8) to entity/relation summary tasks # Apply higher priority (8) to entity/relation summary tasks
use_llm_func = partial(use_llm_func, _priority=8) use_llm_func = partial(use_llm_func, _priority=8)
tokenizer: Tokenizer = global_config["tokenizer"]
llm_max_tokens = global_config["summary_max_tokens"]
language = global_config["addon_params"].get( language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"] "language", PROMPTS["DEFAULT_LANGUAGE"]
) )
tokens = tokenizer.encode(description)
### summarize is not determined here anymore (It's determined by num_fragment now)
# if len(tokens) < summary_max_tokens: # No need for summary
# return description
prompt_template = PROMPTS["summarize_entity_descriptions"] prompt_template = PROMPTS["summarize_entity_descriptions"]
use_description = tokenizer.decode(tokens[:llm_max_tokens])
# Prepare context for the prompt
context_base = dict( context_base = dict(
entity_name=entity_or_relation_name, entity_name=entity_or_relation_name,
description_list=use_description.split(GRAPH_FIELD_SEP), description_list="\n".join(description_list),
language=language, language=language,
) )
use_prompt = prompt_template.format(**context_base) use_prompt = prompt_template.format(**context_base)
logger.debug(f"Trigger summary: {entity_or_relation_name}")
logger.debug(
f"Summarizing {len(description_list)} descriptions for: {entity_or_relation_name}"
)
# Use LLM function with cache (higher priority for summary generation) # Use LLM function with cache (higher priority for summary generation)
summary = await use_llm_func_with_cache( summary = await use_llm_func_with_cache(
use_prompt, use_prompt,
use_llm_func, use_llm_func,
llm_response_cache=llm_response_cache, llm_response_cache=llm_response_cache,
# max_tokens=summary_max_tokens,
cache_type="extract", cache_type="extract",
) )
return summary return summary
@ -413,7 +518,7 @@ async def _rebuild_knowledge_from_chunks(
) )
rebuilt_entities_count += 1 rebuilt_entities_count += 1
status_message = ( status_message = (
f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks" f"Entity `{entity_name}` rebuilt from {len(chunk_ids)} chunks"
) )
logger.info(status_message) logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None: if pipeline_status is not None and pipeline_status_lock is not None:
@ -453,7 +558,7 @@ async def _rebuild_knowledge_from_chunks(
global_config=global_config, global_config=global_config,
) )
rebuilt_relationships_count += 1 rebuilt_relationships_count += 1
status_message = f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks" status_message = f"Relationship `{src}->{tgt}` rebuilt from {len(chunk_ids)} chunks"
logger.info(status_message) logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None: if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock: async with pipeline_status_lock:
@ -736,21 +841,20 @@ async def _rebuild_single_entity(
edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP) edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP)
file_paths.update(edge_file_paths) file_paths.update(edge_file_paths)
# Generate description from relationships or fallback to current # deduplicate descriptions
if relationship_descriptions: description_list = list(dict.fromkeys(relationship_descriptions))
combined_description = GRAPH_FIELD_SEP.join(relationship_descriptions)
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1
if num_fragment >= force_llm_summary_on_merge: # Generate description from relationships or fallback to current
final_description = await _handle_entity_relation_summary( if description_list:
entity_name, force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
combined_description, final_description = await _handle_entity_relation_summary(
global_config, entity_name,
llm_response_cache=llm_response_cache, description_list,
) force_llm_summary_on_merge,
else: GRAPH_FIELD_SEP,
final_description = combined_description global_config,
llm_response_cache=llm_response_cache,
)
else: else:
final_description = current_entity.get("description", "") final_description = current_entity.get("description", "")
@ -772,16 +876,9 @@ async def _rebuild_single_entity(
file_paths.add(entity_data["file_path"]) file_paths.add(entity_data["file_path"])
# Remove duplicates while preserving order # Remove duplicates while preserving order
descriptions = list(dict.fromkeys(descriptions)) description_list = list(dict.fromkeys(descriptions))
entity_types = list(dict.fromkeys(entity_types)) entity_types = list(dict.fromkeys(entity_types))
# Combine all descriptions
combined_description = (
GRAPH_FIELD_SEP.join(descriptions)
if descriptions
else current_entity.get("description", "")
)
# Get most common entity type # Get most common entity type
entity_type = ( entity_type = (
max(set(entity_types), key=entity_types.count) max(set(entity_types), key=entity_types.count)
@ -791,17 +888,17 @@ async def _rebuild_single_entity(
# Generate final description and update storage # Generate final description and update storage
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1 if description_list:
if num_fragment >= force_llm_summary_on_merge:
final_description = await _handle_entity_relation_summary( final_description = await _handle_entity_relation_summary(
entity_name, entity_name,
combined_description, description_list,
force_llm_summary_on_merge,
GRAPH_FIELD_SEP,
global_config, global_config,
llm_response_cache=llm_response_cache, llm_response_cache=llm_response_cache,
) )
else: else:
final_description = combined_description final_description = current_entity.get("description", "")
await _update_entity_storage(final_description, entity_type, file_paths) await _update_entity_storage(final_description, entity_type, file_paths)
@ -859,45 +956,38 @@ async def _rebuild_single_relationship(
file_paths.add(rel_data["file_path"]) file_paths.add(rel_data["file_path"])
# Remove duplicates while preserving order # Remove duplicates while preserving order
descriptions = list(dict.fromkeys(descriptions)) description_list = list(dict.fromkeys(descriptions))
keywords = list(dict.fromkeys(keywords)) keywords = list(dict.fromkeys(keywords))
# Combine descriptions and keywords (fallback to keep currunt unchanged)
combined_description = (
GRAPH_FIELD_SEP.join(descriptions)
if descriptions
else current_relationship.get("description", "")
)
combined_keywords = ( combined_keywords = (
", ".join(set(keywords)) ", ".join(set(keywords))
if keywords if keywords
else current_relationship.get("keywords", "") else current_relationship.get("keywords", "")
) )
# weight = (
# sum(weights) / len(weights)
# if weights
# else current_relationship.get("weight", 1.0)
# )
weight = sum(weights) if weights else current_relationship.get("weight", 1.0) weight = sum(weights) if weights else current_relationship.get("weight", 1.0)
# Use summary if description has too many fragments # Use summary if description has too many fragments
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = combined_description.count(GRAPH_FIELD_SEP) + 1 if description_list:
if num_fragment >= force_llm_summary_on_merge:
final_description = await _handle_entity_relation_summary( final_description = await _handle_entity_relation_summary(
f"{src}-{tgt}", f"{src}-{tgt}",
combined_description, description_list,
force_llm_summary_on_merge,
GRAPH_FIELD_SEP,
global_config, global_config,
llm_response_cache=llm_response_cache, llm_response_cache=llm_response_cache,
) )
else: else:
final_description = combined_description # fallback to keep current(unchanged)
final_description = current_relationship.get("description", "")
# Update relationship in graph storage # Update relationship in graph storage
updated_relationship_data = { updated_relationship_data = {
**current_relationship, **current_relationship,
"description": final_description, "description": final_description
if final_description
else current_relationship.get("description", ""),
"keywords": combined_keywords, "keywords": combined_keywords,
"weight": weight, "weight": weight,
"source_id": GRAPH_FIELD_SEP.join(chunk_ids), "source_id": GRAPH_FIELD_SEP.join(chunk_ids),
@ -971,21 +1061,16 @@ async def _merge_nodes_then_upsert(
reverse=True, reverse=True,
)[0][0] )[0][0]
description = GRAPH_FIELD_SEP.join( description_list = already_description + list(
already_description dict.fromkeys([dp["description"] for dp in nodes_data if dp.get("description")])
+ list(
dict.fromkeys(
[dp["description"] for dp in nodes_data if dp.get("description")]
)
)
) )
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = description.count(GRAPH_FIELD_SEP) + 1 num_fragment = len(description_list)
already_fragment = already_description.count(GRAPH_FIELD_SEP) + 1 already_fragment = already_description.count(GRAPH_FIELD_SEP) + 1
if num_fragment > 1: if num_fragment > 0:
if num_fragment >= force_llm_summary_on_merge: if num_fragment >= force_llm_summary_on_merge:
status_message = f"LLM merge N: {entity_name} | {already_fragment}+{num_fragment-already_fragment}" status_message = f"LLM merging `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}"
logger.info(status_message) logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None: if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock: async with pipeline_status_lock:
@ -993,17 +1078,23 @@ async def _merge_nodes_then_upsert(
pipeline_status["history_messages"].append(status_message) pipeline_status["history_messages"].append(status_message)
description = await _handle_entity_relation_summary( description = await _handle_entity_relation_summary(
entity_name, entity_name,
description, description_list,
force_llm_summary_on_merge,
GRAPH_FIELD_SEP,
global_config, global_config,
llm_response_cache, llm_response_cache,
) )
else: else:
status_message = f"Merge N: {entity_name} | {already_fragment}+{num_fragment-already_fragment}" status_message = f"Merging `{entity_name}` | {already_fragment}+{num_fragment-already_fragment}"
logger.info(status_message) logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None: if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock: async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message) pipeline_status["history_messages"].append(status_message)
description = GRAPH_FIELD_SEP.join(description_list)
else:
logger.error(f"Entity {entity_name} has no description")
description = "(no description)"
source_id = GRAPH_FIELD_SEP.join( source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in nodes_data] + already_source_ids) set([dp["source_id"] for dp in nodes_data] + already_source_ids)
@ -1084,21 +1175,16 @@ async def _merge_edges_then_upsert(
# Process edges_data with None checks # Process edges_data with None checks
weight = sum([dp["weight"] for dp in edges_data] + already_weights) weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join( description_list = already_description + list(
already_description dict.fromkeys([dp["description"] for dp in edges_data if dp.get("description")])
+ list(
dict.fromkeys(
[dp["description"] for dp in edges_data if dp.get("description")]
)
)
) )
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
num_fragment = description.count(GRAPH_FIELD_SEP) + 1 num_fragment = len(description_list)
already_fragment = already_description.count(GRAPH_FIELD_SEP) + 1 already_fragment = already_description.count(GRAPH_FIELD_SEP) + 1
if num_fragment > 1: if num_fragment > 0:
if num_fragment >= force_llm_summary_on_merge: if num_fragment >= force_llm_summary_on_merge:
status_message = f"LLM merge E: {src_id} - {tgt_id} | {already_fragment}+{num_fragment-already_fragment}" status_message = f"LLM merging `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}"
logger.info(status_message) logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None: if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock: async with pipeline_status_lock:
@ -1106,17 +1192,23 @@ async def _merge_edges_then_upsert(
pipeline_status["history_messages"].append(status_message) pipeline_status["history_messages"].append(status_message)
description = await _handle_entity_relation_summary( description = await _handle_entity_relation_summary(
f"({src_id}, {tgt_id})", f"({src_id}, {tgt_id})",
description, description_list,
force_llm_summary_on_merge,
GRAPH_FIELD_SEP,
global_config, global_config,
llm_response_cache, llm_response_cache,
) )
else: else:
status_message = f"Merge E: {src_id} - {tgt_id} | {already_fragment}+{num_fragment-already_fragment}" status_message = f"Merging `{src_id} - {tgt_id}` | {already_fragment}+{num_fragment-already_fragment}"
logger.info(status_message) logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None: if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock: async with pipeline_status_lock:
pipeline_status["latest_message"] = status_message pipeline_status["latest_message"] = status_message
pipeline_status["history_messages"].append(status_message) pipeline_status["history_messages"].append(status_message)
description = GRAPH_FIELD_SEP.join(description_list)
else:
logger.error(f"Edge {src_id} - {tgt_id} has no description")
description = "(no description)"
# Split all existing and new keywords into individual terms, then combine and deduplicate # Split all existing and new keywords into individual terms, then combine and deduplicate
all_keywords = set() all_keywords = set()