feat(performance): Optimize document deletion with entity/relation index

- Introduces an index mapping documents to their corresponding entities and relations. This significantly speeds up `adelete_by_doc_id` by replacing slow graph traversal with a fast key-value lookup.
- Refactors the ingestion pipeline (`merge_nodes_and_edges`) to populate this new index. Adds a one-time data migration routine to backfill the index for existing data.
yangdx 2025-08-03 09:19:02 +08:00
parent 2f0aa7ed12
commit 091f2b42c3
5 changed files with 446 additions and 69 deletions
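For context, here is a minimal sketch of the index this commit introduces, using plain in-memory dicts in place of LightRAG's KV storage backends. The record shapes mirror the diff below; the helper names (`index_document`, `affected_by`) are illustrative only, not part of the library's API.

```python
# Hypothetical in-memory stand-ins for the full_entities / full_relations KV stores.
# Record shapes follow the diff: full_entities[doc_id] = {"entity_names": [...]},
# full_relations[doc_id] = {"relation_pairs": [[src, tgt], ...]}.
full_entities: dict[str, dict] = {}
full_relations: dict[str, dict] = {}

def index_document(doc_id: str, entity_names: set[str], relation_pairs: set[tuple[str, str]]) -> None:
    """Record which entities/relations a document produced (done during merge in the diff)."""
    full_entities[doc_id] = {"entity_names": sorted(entity_names)}
    full_relations[doc_id] = {"relation_pairs": [list(p) for p in relation_pairs]}

def affected_by(doc_id: str) -> tuple[list[str], list[tuple[str, str]]]:
    """Deletion-time lookup: a key-value read instead of traversing the whole graph."""
    names = full_entities.get(doc_id, {}).get("entity_names", [])
    pairs = [tuple(p) for p in full_relations.get(doc_id, {}).get("relation_pairs", [])]
    return names, pairs

index_document("doc-1", {"Alice", "Bob"}, {("Alice", "Bob")})
print(affected_by("doc-1"))  # (['Alice', 'Bob'], [('Alice', 'Bob')])
```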


@@ -654,6 +654,22 @@ class BaseGraphStorage(StorageNameSpace, ABC):
indicating whether the graph was truncated due to max_nodes limit
"""
@abstractmethod
async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph.
Returns:
A list of all nodes, where each node is a dictionary of its properties
"""
@abstractmethod
async def get_all_edges(self) -> list[dict]:
"""Get all edges in the graph.
Returns:
A list of all edges, where each edge is a dictionary of its properties
"""
class DocStatus(str, Enum):
"""Document processing status"""


@@ -393,6 +393,35 @@ class NetworkXStorage(BaseGraphStorage):
matching_edges.append(edge_data_with_nodes)
return matching_edges
async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph.
Returns:
A list of all nodes, where each node is a dictionary of its properties
"""
graph = await self._get_graph()
all_nodes = []
for node_id, node_data in graph.nodes(data=True):
node_data_with_id = node_data.copy()
node_data_with_id["id"] = node_id
all_nodes.append(node_data_with_id)
return all_nodes
async def get_all_edges(self) -> list[dict]:
"""Get all edges in the graph.
Returns:
A list of all edges, where each edge is a dictionary of its properties
"""
graph = await self._get_graph()
all_edges = []
for u, v, edge_data in graph.edges(data=True):
edge_data_with_nodes = edge_data.copy()
edge_data_with_nodes["source"] = u
edge_data_with_nodes["target"] = v
all_edges.append(edge_data_with_nodes)
return all_edges
async def index_done_callback(self) -> bool:
"""Save data to disk"""
async with self._storage_lock:

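To make the return shape of the two new methods concrete, here is a small standalone sketch using networkx directly; the async wrappers, the storage lock, and the persistence handled by NetworkXStorage are omitted, and the example data is invented.

```python
import networkx as nx

g = nx.Graph()
g.add_node("Alice", entity_type="PERSON", source_id="chunk-1")
g.add_edge("Alice", "Bob", description="knows", source_id="chunk-1")

# Same shaping as get_all_nodes / get_all_edges above: copy the property dict and
# attach the node id (resp. edge endpoints) so callers get self-describing records.
all_nodes = [{**data, "id": node_id} for node_id, data in g.nodes(data=True)]
all_edges = [{**data, "source": u, "target": v} for u, v, data in g.edges(data=True)]

print(all_nodes)  # [{'entity_type': 'PERSON', 'source_id': 'chunk-1', 'id': 'Alice'}, {'id': 'Bob'}]
print(all_edges)  # [{'description': 'knows', 'source_id': 'chunk-1', 'source': 'Alice', 'target': 'Bob'}]
```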

@@ -453,14 +453,26 @@ class LightRAG:
embedding_func=self.embedding_func,
)
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_TEXT_CHUNKS,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_FULL_DOCS,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_TEXT_CHUNKS,
self.full_entities: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_FULL_ENTITIES,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
self.full_relations: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=NameSpace.KV_STORE_FULL_RELATIONS,
workspace=self.workspace,
embedding_func=self.embedding_func,
)
@@ -553,6 +565,8 @@ class LightRAG:
for storage in (
self.full_docs,
self.text_chunks,
self.full_entities,
self.full_relations,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
@@ -576,6 +590,8 @@ class LightRAG:
for storage in (
self.full_docs,
self.text_chunks,
self.full_entities,
self.full_relations,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
@@ -591,6 +607,159 @@ class LightRAG:
self._storages_status = StoragesStatus.FINALIZED
logger.debug("Finalized Storages")
async def _check_and_migrate_data(self):
"""Check if data migration is needed and perform migration if necessary"""
try:
# Check if migration is needed:
# 1. chunk_entity_relation_graph has entities and relations (count > 0)
# 2. full_entities and full_relations are empty
# Get all entity labels from graph
all_entity_labels = await self.chunk_entity_relation_graph.get_all_labels()
if not all_entity_labels:
logger.debug("No entities found in graph, skipping migration check")
return
# Check if full_entities and full_relations are empty
# Get all processed documents to check their entity/relation data
try:
processed_docs = await self.doc_status.get_docs_by_status(
DocStatus.PROCESSED
)
if not processed_docs:
logger.debug("No processed documents found, skipping migration")
return
# Check first few documents to see if they have full_entities/full_relations data
migration_needed = True
checked_count = 0
max_check = min(5, len(processed_docs)) # Check up to 5 documents
for doc_id in list(processed_docs.keys())[:max_check]:
checked_count += 1
entity_data = await self.full_entities.get_by_id(doc_id)
relation_data = await self.full_relations.get_by_id(doc_id)
if entity_data or relation_data:
migration_needed = False
break
if not migration_needed:
logger.debug(
"Full entities/relations data already exists, no migration needed"
)
return
logger.info(
f"Data migration needed: found {len(all_entity_labels)} entities in graph but no full_entities/full_relations data"
)
# Perform migration
await self._migrate_entity_relation_data(processed_docs)
except Exception as e:
logger.error(f"Error during migration check: {e}")
# Don't raise the error, just log it to avoid breaking initialization
except Exception as e:
logger.error(f"Error in data migration check: {e}")
# Don't raise the error to avoid breaking initialization
async def _migrate_entity_relation_data(self, processed_docs: dict):
"""Migrate existing entity and relation data to full_entities and full_relations storage"""
logger.info(f"Starting data migration for {len(processed_docs)} documents")
# Create mapping from chunk_id to doc_id
chunk_to_doc = {}
for doc_id, doc_status in processed_docs.items():
chunk_ids = (
doc_status.chunks_list
if hasattr(doc_status, "chunks_list") and doc_status.chunks_list
else []
)
for chunk_id in chunk_ids:
chunk_to_doc[chunk_id] = doc_id
# Initialize document entity and relation mappings
doc_entities = {} # doc_id -> set of entity_names
doc_relations = {} # doc_id -> set of relation_pairs (as tuples)
# Get all nodes and edges from graph
all_nodes = await self.chunk_entity_relation_graph.get_all_nodes()
all_edges = await self.chunk_entity_relation_graph.get_all_edges()
# Process all nodes once
for node in all_nodes:
if "source_id" in node:
entity_id = node.get("entity_id") or node.get("id")
if not entity_id:
continue
# Get chunk IDs from source_id
source_ids = node["source_id"].split(GRAPH_FIELD_SEP)
# Find which documents this entity belongs to
for chunk_id in source_ids:
doc_id = chunk_to_doc.get(chunk_id)
if doc_id:
if doc_id not in doc_entities:
doc_entities[doc_id] = set()
doc_entities[doc_id].add(entity_id)
# Process all edges once
for edge in all_edges:
if "source_id" in edge:
src = edge.get("source")
tgt = edge.get("target")
if not src or not tgt:
continue
# Get chunk IDs from source_id
source_ids = edge["source_id"].split(GRAPH_FIELD_SEP)
# Find which documents this relation belongs to
for chunk_id in source_ids:
doc_id = chunk_to_doc.get(chunk_id)
if doc_id:
if doc_id not in doc_relations:
doc_relations[doc_id] = set()
# Use tuple for set operations, convert to list later
doc_relations[doc_id].add((src, tgt))
# Store the results in full_entities and full_relations
migration_count = 0
# Store entities
if doc_entities:
entities_data = {}
for doc_id, entity_set in doc_entities.items():
entities_data[doc_id] = {"entity_names": list(entity_set)}
await self.full_entities.upsert(entities_data)
# Store relations
if doc_relations:
relations_data = {}
for doc_id, relation_set in doc_relations.items():
# Convert tuples back to lists
relations_data[doc_id] = {
"relation_pairs": [list(pair) for pair in relation_set]
}
await self.full_relations.upsert(relations_data)
migration_count = len(
set(list(doc_entities.keys()) + list(doc_relations.keys()))
)
# Persist the migrated data
await self.full_entities.index_done_callback()
await self.full_relations.index_done_callback()
logger.info(
f"Data migration completed: migrated {migration_count} documents with entities/relations"
)
async def get_graph_labels(self):
text = await self.chunk_entity_relation_graph.get_all_labels()
return text
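The core of the backfill above is the attribution step: each node/edge carries a `source_id` of chunk ids joined with `GRAPH_FIELD_SEP`, and each processed document's `chunks_list` says which document a chunk came from. A compact, self-contained illustration follows; the separator value and the sample data are assumptions for the sketch, not taken from the diff.

```python
GRAPH_FIELD_SEP = "<SEP>"  # assumed separator value for this sketch

# doc_id -> chunks_list, as recorded in doc_status for processed documents
processed_docs = {"doc-1": ["chunk-1", "chunk-2"], "doc-2": ["chunk-3"]}
nodes = [{"id": "Alice", "source_id": f"chunk-1{GRAPH_FIELD_SEP}chunk-3"}]

# Invert chunks_list into chunk_id -> doc_id
chunk_to_doc = {c: d for d, chunks in processed_docs.items() for c in chunks}

doc_entities: dict[str, set[str]] = {}
for node in nodes:
    for chunk_id in node["source_id"].split(GRAPH_FIELD_SEP):
        doc_id = chunk_to_doc.get(chunk_id)
        if doc_id:
            doc_entities.setdefault(doc_id, set()).add(node["id"])

print(doc_entities)  # {'doc-1': {'Alice'}, 'doc-2': {'Alice'}}
```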
@@ -1229,6 +1398,9 @@ class LightRAG:
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
full_entities_storage=self.full_entities,
full_relations_storage=self.full_relations,
doc_id=doc_id,
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,
@@ -1401,6 +1573,8 @@ class LightRAG:
self.full_docs,
self.doc_status,
self.text_chunks,
self.full_entities,
self.full_relations,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
@@ -1959,21 +2133,54 @@ class LightRAG:
graph_db_lock = get_graph_db_lock(enable_logging=False)
async with graph_db_lock:
try:
# Get all affected nodes and edges in batch
# logger.info(
# f"Analyzing affected entities and relationships for {len(chunk_ids)} chunks"
# )
affected_nodes = (
await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids(
list(chunk_ids)
)
)
# Get affected entities and relations from full_entities and full_relations storage
doc_entities_data = await self.full_entities.get_by_id(doc_id)
doc_relations_data = await self.full_relations.get_by_id(doc_id)
affected_edges = (
await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
list(chunk_ids)
affected_nodes = []
affected_edges = []
# Get entity data from graph storage using entity names from full_entities
if doc_entities_data and "entity_names" in doc_entities_data:
entity_names = doc_entities_data["entity_names"]
# get_nodes_batch returns dict[str, dict], need to convert to list[dict]
nodes_dict = (
await self.chunk_entity_relation_graph.get_nodes_batch(
entity_names
)
)
)
for entity_name in entity_names:
node_data = nodes_dict.get(entity_name)
if node_data:
# Ensure compatibility with existing logic that expects "id" field
if "id" not in node_data:
node_data["id"] = entity_name
affected_nodes.append(node_data)
# Get relation data from graph storage using relation pairs from full_relations
if doc_relations_data and "relation_pairs" in doc_relations_data:
relation_pairs = doc_relations_data["relation_pairs"]
edge_pairs_dicts = [
{"src": pair[0], "tgt": pair[1]} for pair in relation_pairs
]
# get_edges_batch returns dict[tuple[str, str], dict], need to convert to list[dict]
edges_dict = (
await self.chunk_entity_relation_graph.get_edges_batch(
edge_pairs_dicts
)
)
for pair in relation_pairs:
src, tgt = pair[0], pair[1]
edge_key = (src, tgt)
edge_data = edges_dict.get(edge_key)
if edge_data:
# Ensure compatibility with existing logic that expects "source" and "target" fields
if "source" not in edge_data:
edge_data["source"] = src
if "target" not in edge_data:
edge_data["target"] = tgt
affected_edges.append(edge_data)
except Exception as e:
logger.error(f"Failed to analyze affected graph elements: {e}")
@@ -2125,7 +2332,17 @@ class LightRAG:
f"Failed to rebuild knowledge graph: {e}"
) from e
# 9. Delete original document and status
# 9. Delete from full_entities and full_relations storage
try:
await self.full_entities.delete([doc_id])
await self.full_relations.delete([doc_id])
except Exception as e:
logger.error(f"Failed to delete from full_entities/full_relations: {e}")
raise Exception(
f"Failed to delete from full_entities/full_relations: {e}"
) from e
# 10. Delete original document and status
try:
await self.full_docs.delete([doc_id])
await self.doc_status.delete([doc_id])


@@ -7,6 +7,8 @@ class NameSpace:
KV_STORE_FULL_DOCS = "full_docs"
KV_STORE_TEXT_CHUNKS = "text_chunks"
KV_STORE_LLM_RESPONSE_CACHE = "llm_response_cache"
KV_STORE_FULL_ENTITIES = "full_entities"
KV_STORE_FULL_RELATIONS = "full_relations"
VECTOR_STORE_ENTITIES = "entities"
VECTOR_STORE_RELATIONSHIPS = "relationships"


@@ -504,9 +504,6 @@ async def _rebuild_knowledge_from_chunks(
# Re-raise the exception to notify the caller
raise task.exception()
# If all tasks completed successfully, collect results
# (No need to collect results since these tasks don't return values)
# Final status report
status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully."
if failed_entities_count > 0 or failed_relationships_count > 0:
@@ -1024,6 +1021,7 @@ async def _merge_edges_then_upsert(
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
added_entities: list = None, # New parameter to track entities added during edge processing
):
if src_id == tgt_id:
return None
@@ -1105,17 +1103,27 @@ async def _merge_edges_then_upsert(
for need_insert_id in [src_id, tgt_id]:
if not (await knowledge_graph_inst.has_node(need_insert_id)):
await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"entity_id": need_insert_id,
"source_id": source_id,
"description": description,
node_data = {
"entity_id": need_insert_id,
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN",
"file_path": file_path,
"created_at": int(time.time()),
}
await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)
# Track entities added during edge processing
if added_entities is not None:
entity_data = {
"entity_name": need_insert_id,
"entity_type": "UNKNOWN",
"description": description,
"source_id": source_id,
"file_path": file_path,
"created_at": int(time.time()),
},
)
}
added_entities.append(entity_data)
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
@@ -1178,6 +1186,9 @@ async def merge_nodes_and_edges(
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict[str, str],
full_entities_storage: BaseKVStorage = None,
full_relations_storage: BaseKVStorage = None,
doc_id: str = None,
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
@@ -1185,7 +1196,12 @@ async def merge_nodes_and_edges(
total_files: int = 0,
file_path: str = "unknown_source",
) -> None:
"""Merge nodes and edges from extraction results
"""Two-phase merge: process all entities first, then all relationships
This approach ensures data consistency by:
1. Phase 1: Process all entities concurrently
2. Phase 2: Process all relationships concurrently (may add missing entities)
3. Phase 3: Update full_entities and full_relations storage with final results
Args:
chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
@@ -1193,9 +1209,15 @@ async def merge_nodes_and_edges(
entity_vdb: Entity vector database
relationships_vdb: Relationship vector database
global_config: Global configuration
full_entities_storage: Storage for document entity lists
full_relations_storage: Storage for document relation lists
doc_id: Document ID for storage indexing
pipeline_status: Pipeline status dictionary
pipeline_status_lock: Lock for pipeline status
llm_response_cache: LLM response cache
current_file_number: Current file number for logging
total_files: Total files for logging
file_path: File path for logging
"""
# Collect all nodes and edges from all chunks
@@ -1212,11 +1234,9 @@ async def merge_nodes_and_edges(
sorted_edge_key = tuple(sorted(edge_key))
all_edges[sorted_edge_key].extend(edges)
# Centralized processing of all nodes and edges
total_entities_count = len(all_nodes)
total_relations_count = len(all_edges)
# Merge nodes and edges
log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
async with pipeline_status_lock:
@@ -1227,8 +1247,8 @@ async def merge_nodes_and_edges(
graph_max_async = global_config.get("llm_model_max_async", 4) * 2
semaphore = asyncio.Semaphore(graph_max_async)
# Process and update all entities and relationships in parallel
log_message = f"Processing: {total_entities_count} entities and {total_relations_count} relations (async: {graph_max_async})"
# ===== Phase 1: Process all entities concurrently =====
log_message = f"Phase 1: Processing {total_entities_count} entities (async: {graph_max_async})"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
@@ -1263,18 +1283,53 @@ async def merge_nodes_and_edges(
await entity_vdb.upsert(data_for_vdb)
return entity_data
# Create entity processing tasks
entity_tasks = []
for entity_name, entities in all_nodes.items():
task = asyncio.create_task(_locked_process_entity_name(entity_name, entities))
entity_tasks.append(task)
# Execute entity tasks with error handling
processed_entities = []
if entity_tasks:
done, pending = await asyncio.wait(
entity_tasks, return_when=asyncio.FIRST_EXCEPTION
)
# Check if any task raised an exception
for task in done:
if task.exception():
# If a task failed, cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the exception to notify the caller
raise task.exception()
# If all tasks completed successfully, collect results
processed_entities = [task.result() for task in entity_tasks]
# ===== Phase 2: Process all relationships concurrently =====
log_message = f"Phase 2: Processing {total_relations_count} relations (async: {graph_max_async})"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
async def _locked_process_edges(edge_key, edges):
async with semaphore:
workspace = global_config.get("workspace", "")
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
# Sort the edge_key components to ensure consistent lock key generation
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
# logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
async with get_storage_keyed_lock(
sorted_edge_key,
namespace=namespace,
enable_logging=False,
):
added_entities = [] # Track entities added during edge processing
edge_data = await _merge_edges_then_upsert(
edge_key[0],
edge_key[1],
@@ -1284,9 +1339,11 @@ async def merge_nodes_and_edges(
pipeline_status,
pipeline_status_lock,
llm_response_cache,
added_entities, # Pass list to collect added entities
)
if edge_data is None:
return None
return None, []
if relationships_vdb is not None:
data_for_vdb = {
@@ -1303,50 +1360,106 @@ async def merge_nodes_and_edges(
}
}
await relationships_vdb.upsert(data_for_vdb)
return edge_data
return edge_data, added_entities
# Create a single task queue for both entities and edges
tasks = []
# Create relationship processing tasks
edge_tasks = []
for edge_key, edges in all_edges.items():
task = asyncio.create_task(_locked_process_edges(edge_key, edges))
edge_tasks.append(task)
# Add entity processing tasks
for entity_name, entities in all_nodes.items():
tasks.append(
asyncio.create_task(_locked_process_entity_name(entity_name, entities))
# Execute relationship tasks with error handling
processed_edges = []
all_added_entities = []
if edge_tasks:
done, pending = await asyncio.wait(
edge_tasks, return_when=asyncio.FIRST_EXCEPTION
)
# Add edge processing tasks
for edge_key, edges in all_edges.items():
tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
# Check if any task raised an exception
for task in done:
if task.exception():
# If a task failed, cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the exception to notify the caller
raise task.exception()
# Check if there are any tasks to process
if not tasks:
log_message = f"No entities or relationships to process for {file_path}"
logger.info(log_message)
if pipeline_status_lock is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
return
# If all tasks completed successfully, collect results
for task in edge_tasks:
edge_data, added_entities = task.result()
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
# Execute all tasks in parallel with semaphore control and early failure detection
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
# ===== Phase 3: Update full_entities and full_relations storage =====
if full_entities_storage and full_relations_storage and doc_id:
try:
# Merge all entities: original entities + entities added during edge processing
final_entity_names = set()
# Check if any task raised an exception
for task in done:
if task.exception():
# If a task failed, cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Add original processed entities
for entity_data in processed_entities:
if entity_data and entity_data.get("entity_name"):
final_entity_names.add(entity_data["entity_name"])
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Add entities that were added during relationship processing
for added_entity in all_added_entities:
if added_entity and added_entity.get("entity_name"):
final_entity_names.add(added_entity["entity_name"])
# Re-raise the exception to notify the caller
raise task.exception()
# Collect all relation pairs
final_relation_pairs = set()
for edge_data in processed_edges:
if edge_data:
src_id = edge_data.get("src_id")
tgt_id = edge_data.get("tgt_id")
if src_id and tgt_id:
relation_pair = tuple(sorted([src_id, tgt_id]))
final_relation_pairs.add(relation_pair)
# If all tasks completed successfully, collect results
# (No need to collect results since these tasks don't return values)
# Update storage
if final_entity_names:
await full_entities_storage.upsert(
{
doc_id: {
"entity_names": list(final_entity_names),
"count": len(final_entity_names),
}
}
)
if final_relation_pairs:
await full_relations_storage.upsert(
{
doc_id: {
"relation_pairs": [
list(pair) for pair in final_relation_pairs
],
"count": len(final_relation_pairs),
}
}
)
logger.debug(
f"Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations"
)
except Exception as e:
logger.error(
f"Failed to update entity-relation index for document {doc_id}: {e}"
)
# Don't raise exception to avoid affecting main flow
log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} added entities, {len(processed_edges)} relations"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
async def extract_entities(
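
Both merge phases reuse the same fail-fast concurrency pattern: run everything under a semaphore, stop at the first failure, cancel the remaining tasks, and only then collect results. A self-contained sketch of that pattern is shown below; `run_all` and `work` are hypothetical names, separate from the library code.

```python
import asyncio

async def run_all(coros):
    """Run coroutines concurrently; cancel the rest and re-raise on the first failure."""
    tasks = [asyncio.create_task(c) for c in coros]
    if not tasks:
        return []
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
    for task in done:
        if task.exception():
            for p in pending:
                p.cancel()
            if pending:
                await asyncio.wait(pending)  # let cancellations settle before re-raising
            raise task.exception()
    return [t.result() for t in tasks]  # no failures: all tasks completed, order matches input

async def main():
    sem = asyncio.Semaphore(4)  # analogous to graph_max_async in the diff

    async def work(i):
        async with sem:
            await asyncio.sleep(0.01)
            return i * i

    print(await run_all(work(i) for i in range(5)))  # [0, 1, 4, 9, 16]

asyncio.run(main())
```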