fix: sync core modules with upstream for compatibility

Raphaël MANSUY 2025-12-04 19:10:46 +08:00
parent 395b2d82de
commit ed73def994
4 changed files with 472 additions and 101 deletions

lightrag/constants.py

@@ -13,6 +13,7 @@ DEFAULT_MAX_GRAPH_NODES = 1000
 # Default values for extraction settings
 DEFAULT_SUMMARY_LANGUAGE = "English"  # Default language for document processing
 DEFAULT_MAX_GLEANING = 1
+DEFAULT_ENTITY_NAME_MAX_LENGTH = 256
 # Number of description fragments to trigger LLM summary
 DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 8

@@ -37,7 +38,7 @@ DEFAULT_ENTITY_TYPES = [
     "NaturalObject",
 ]
-# Separator for graph fields
+# Separator for: description, source_id and relation-key fields (cannot be changed after data is inserted)
 GRAPH_FIELD_SEP = "<SEP>"
 # Query and retrieval configuration defaults

@@ -58,17 +59,20 @@ DEFAULT_MIN_RERANK_SCORE = 0.0
 DEFAULT_RERANK_BINDING = "null"
 # Default source ids limit in meta data for entity and relation
-DEFAULT_MAX_SOURCE_IDS_PER_ENTITY = 3
-DEFAULT_MAX_SOURCE_IDS_PER_RELATION = 3
-SOURCE_IDS_LIMIT_METHOD_KEEP = "KEEP"  # Keep oldest
-SOURCE_IDS_LIMIT_METHOD_FIFO = "FIFO"  # First In First Out (Keep newest)
-DEFAULT_SOURCE_IDS_LIMIT_METHOD = SOURCE_IDS_LIMIT_METHOD_KEEP
+DEFAULT_MAX_SOURCE_IDS_PER_ENTITY = 300
+DEFAULT_MAX_SOURCE_IDS_PER_RELATION = 300
+### control chunk_ids limitation method: FIFO, KEEP
+### FIFO: First in first out
+### KEEP: Keep oldest (less merge action and faster)
+SOURCE_IDS_LIMIT_METHOD_KEEP = "KEEP"
+SOURCE_IDS_LIMIT_METHOD_FIFO = "FIFO"
+DEFAULT_SOURCE_IDS_LIMIT_METHOD = SOURCE_IDS_LIMIT_METHOD_FIFO
 VALID_SOURCE_IDS_LIMIT_METHODS = {
     SOURCE_IDS_LIMIT_METHOD_KEEP,
     SOURCE_IDS_LIMIT_METHOD_FIFO,
 }
-# Default file_path limit in meta data for entity and relation (Use same limit method as source_ids)
-DEFAULT_MAX_FILE_PATHS = 2
+# Maximum number of file paths stored in entity/relation file_path field (for display only, does not affect query performance)
+DEFAULT_MAX_FILE_PATHS = 100
 # Field length of file_path in Milvus Schema for entity and relation (Should not be changed)
 # file_path must store all file paths up to the DEFAULT_MAX_FILE_PATHS limit within the metadata.
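Note on the KEEP/FIFO switch above: the default changes from KEEP to FIFO and the per-entity/per-relation caps rise from 3 to 300. A minimal sketch of the two truncation semantics, assuming chunk ids are stored in insertion order (a hypothetical simplification, not the library's actual apply_source_ids_limit):

```python
# Hypothetical sketch of the two limit methods (not the upstream code).
def apply_limit_sketch(chunk_ids: list[str], limit: int, method: str) -> list[str]:
    if len(chunk_ids) <= limit:
        return chunk_ids
    if method == "KEEP":       # keep the oldest ids (fewer merges, faster)
        return chunk_ids[:limit]
    return chunk_ids[-limit:]  # FIFO: evict oldest, keep the newest ids


ids = [f"chunk-{i}" for i in range(5)]
assert apply_limit_sketch(ids, 3, "KEEP") == ["chunk-0", "chunk-1", "chunk-2"]
assert apply_limit_sketch(ids, 3, "FIFO") == ["chunk-2", "chunk-3", "chunk-4"]
```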

lightrag/exceptions.py

@@ -96,3 +96,41 @@ class PipelineNotInitializedError(KeyError):
             f" await initialize_pipeline_status(workspace='your_workspace')"
         )
         super().__init__(msg)
+
+
+class PipelineCancelledException(Exception):
+    """Raised when pipeline processing is cancelled by user request."""
+
+    def __init__(self, message: str = "User cancelled"):
+        super().__init__(message)
+        self.message = message
+
+
+class ChunkTokenLimitExceededError(ValueError):
+    """Raised when a chunk exceeds the configured token limit."""
+
+    def __init__(
+        self,
+        chunk_tokens: int,
+        chunk_token_limit: int,
+        chunk_preview: str | None = None,
+    ) -> None:
+        preview = chunk_preview.strip() if chunk_preview else None
+        truncated_preview = preview[:80] if preview else None
+        preview_note = f" Preview: '{truncated_preview}'" if truncated_preview else ""
+        message = (
+            f"Chunk token length {chunk_tokens} exceeds chunk_token_size {chunk_token_limit}."
+            f"{preview_note}"
+        )
+        super().__init__(message)
+        self.chunk_tokens = chunk_tokens
+        self.chunk_token_limit = chunk_token_limit
+        self.chunk_preview = truncated_preview
+
+
+class QdrantMigrationError(Exception):
+    """Raised when Qdrant data migration from legacy collections fails."""
+
+    def __init__(self, message: str):
+        super().__init__(message)
+        self.message = message
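The new ChunkTokenLimitExceededError carries structured fields rather than only a message string; a small usage sketch (values illustrative):

```python
from lightrag.exceptions import ChunkTokenLimitExceededError

try:
    raise ChunkTokenLimitExceededError(
        chunk_tokens=1500,
        chunk_token_limit=1200,
        chunk_preview="An unusually long unsplittable paragraph ...",
    )
except ChunkTokenLimitExceededError as exc:
    # Structured fields let callers report precisely or re-chunk the input.
    print(exc.chunk_tokens, exc.chunk_token_limit)  # 1500 1200
    print(exc.chunk_preview)  # stripped preview, truncated to 80 chars
```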

lightrag/lightrag.py

@@ -260,13 +260,15 @@ class LightRAG:
             - `content`: The text to be split into chunks.
             - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
             - `split_by_character_only`: If True, the text is split only on the specified character.
+            - `chunk_token_size`: The maximum number of tokens per chunk.
             - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
-            - `chunk_token_size`: The maximum number of tokens per chunk.

         The function should return a list of dictionaries (or an awaitable that resolves to a list),
         where each dictionary contains the following keys:
-        - `tokens`: The number of tokens in the chunk.
-        - `content`: The text content of the chunk.
+        - `tokens` (int): The number of tokens in the chunk.
+        - `content` (str): The text content of the chunk.
+        - `chunk_order_index` (int): Zero-based index indicating the chunk's order in the document.

         Defaults to `chunking_by_token_size` if not specified.
         """
@@ -2948,6 +2950,26 @@ class LightRAG:
         data across different storage layers are removed or rebuiled. If entities or relationships
         are partially affected, they will be rebuilded using LLM cached from remaining documents.

+        **Concurrency Control Design:**
+        This function implements a pipeline-based concurrency control to prevent data corruption:
+
+        1. **Single Document Deletion** (when WE acquire pipeline):
+           - Sets job_name to "Single document deletion" (NOT starting with "deleting")
+           - Prevents other adelete_by_doc_id calls from running concurrently
+           - Ensures exclusive access to graph operations for this deletion
+
+        2. **Batch Document Deletion** (when background_delete_documents acquires pipeline):
+           - Sets job_name to "Deleting {N} Documents" (starts with "deleting")
+           - Allows multiple adelete_by_doc_id calls to join the deletion queue
+           - Each call validates the job name to ensure it's part of a deletion operation
+
+        The validation logic `if not job_name.startswith("deleting") or "document" not in job_name`
+        ensures that:
+        - adelete_by_doc_id can only run when pipeline is idle OR during batch deletion
+        - Prevents concurrent single deletions that could cause race conditions
+        - Rejects operations when pipeline is busy with non-deletion tasks
+
         Args:
             doc_id (str): The unique identifier of the document to be deleted.
             delete_llm_cache (bool): Whether to delete cached LLM extraction results

@@ -2955,17 +2977,13 @@
         Returns:
             DeletionResult: An object containing the outcome of the deletion process.
-                - `status` (str): "success", "not_found", or "failure".
+                - `status` (str): "success", "not_found", "not_allowed", or "failure".
                 - `doc_id` (str): The ID of the document attempted to be deleted.
                 - `message` (str): A summary of the operation's result.
-                - `status_code` (int): HTTP status code (e.g., 200, 404, 500).
+                - `status_code` (int): HTTP status code (e.g., 200, 404, 403, 500).
                 - `file_path` (str | None): The file path of the deleted document, if available.
         """
-        deletion_operations_started = False
-        original_exception = None
-        doc_llm_cache_ids: list[str] = []
-
-        # Get pipeline status shared data and lock for status updates
+        # Get pipeline status shared data and lock for validation
         pipeline_status = await get_namespace_data(
             "pipeline_status", workspace=self.workspace
         )

@@ -2973,6 +2991,48 @@
             "pipeline_status", workspace=self.workspace
         )

+        # Track whether WE acquired the pipeline
+        we_acquired_pipeline = False
+
+        # Check and acquire pipeline if needed
+        async with pipeline_status_lock:
+            if not pipeline_status.get("busy", False):
+                # Pipeline is idle - WE acquire it for this deletion
+                we_acquired_pipeline = True
+                pipeline_status.update(
+                    {
+                        "busy": True,
+                        "job_name": "Single document deletion",
+                        "job_start": datetime.now(timezone.utc).isoformat(),
+                        "docs": 1,
+                        "batchs": 1,
+                        "cur_batch": 0,
+                        "request_pending": False,
+                        "cancellation_requested": False,
+                        "latest_message": f"Starting deletion for document: {doc_id}",
+                    }
+                )
+                # Initialize history messages
+                pipeline_status["history_messages"][:] = [
+                    f"Starting deletion for document: {doc_id}"
+                ]
+            else:
+                # Pipeline already busy - verify it's a deletion job
+                job_name = pipeline_status.get("job_name", "").lower()
+                if not job_name.startswith("deleting") or "document" not in job_name:
+                    return DeletionResult(
+                        status="not_allowed",
+                        doc_id=doc_id,
+                        message=f"Deletion not allowed: current job '{pipeline_status.get('job_name')}' is not a document deletion job",
+                        status_code=403,
+                        file_path=None,
+                    )
+                # Pipeline is busy with deletion - proceed without acquiring
+
+        deletion_operations_started = False
+        original_exception = None
+        doc_llm_cache_ids: list[str] = []
+
         async with pipeline_status_lock:
             log_message = f"Starting deletion process for document {doc_id}"
             logger.info(log_message)

@@ -3585,6 +3645,18 @@
                     f"No deletion operations were started for document {doc_id}, skipping persistence"
                 )

+            # Release pipeline only if WE acquired it
+            if we_acquired_pipeline:
+                async with pipeline_status_lock:
+                    pipeline_status["busy"] = False
+                    pipeline_status["cancellation_requested"] = False
+                    completion_msg = (
+                        f"Deletion process completed for document: {doc_id}"
+                    )
+                    pipeline_status["latest_message"] = completion_msg
+                    pipeline_status["history_messages"].append(completion_msg)
+                    logger.info(completion_msg)
+
     async def adelete_by_entity(self, entity_name: str) -> DeletionResult:
         """Asynchronously delete an entity and all its relationships.

lightrag/operate.py

@@ -1,5 +1,6 @@
 from __future__ import annotations

 from functools import partial
+from pathlib import Path

 import asyncio
 import json

@@ -7,6 +8,10 @@ import json_repair
 from typing import Any, AsyncIterator, overload, Literal
 from collections import Counter, defaultdict

+from lightrag.exceptions import (
+    PipelineCancelledException,
+    ChunkTokenLimitExceededError,
+)
 from lightrag.utils import (
     logger,
     compute_mdhash_id,

@@ -67,7 +72,7 @@ from dotenv import load_dotenv
 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
 # the OS environment variables take precedence over the .env file
-load_dotenv(dotenv_path=".env", override=False)
+load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env", override=False)


 def _truncate_entity_identifier(

@@ -96,8 +101,8 @@ def chunking_by_token_size(
     content: str,
     split_by_character: str | None = None,
     split_by_character_only: bool = False,
-    overlap_token_size: int = 128,
-    max_token_size: int = 1024,
+    chunk_overlap_token_size: int = 100,
+    chunk_token_size: int = 1200,
 ) -> list[dict[str, Any]]:
     tokens = tokenizer.encode(content)
     results: list[dict[str, Any]] = []

@@ -107,19 +112,30 @@ def chunking_by_token_size(
         if split_by_character_only:
             for chunk in raw_chunks:
                 _tokens = tokenizer.encode(chunk)
+                if len(_tokens) > chunk_token_size:
+                    logger.warning(
+                        "Chunk split_by_character exceeds token limit: len=%d limit=%d",
+                        len(_tokens),
+                        chunk_token_size,
+                    )
+                    raise ChunkTokenLimitExceededError(
+                        chunk_tokens=len(_tokens),
+                        chunk_token_limit=chunk_token_size,
+                        chunk_preview=chunk[:120],
+                    )
                 new_chunks.append((len(_tokens), chunk))
         else:
             for chunk in raw_chunks:
                 _tokens = tokenizer.encode(chunk)
-                if len(_tokens) > max_token_size:
+                if len(_tokens) > chunk_token_size:
                     for start in range(
-                        0, len(_tokens), max_token_size - overlap_token_size
+                        0, len(_tokens), chunk_token_size - chunk_overlap_token_size
                     ):
                         chunk_content = tokenizer.decode(
-                            _tokens[start : start + max_token_size]
+                            _tokens[start : start + chunk_token_size]
                         )
                         new_chunks.append(
-                            (min(max_token_size, len(_tokens) - start), chunk_content)
+                            (min(chunk_token_size, len(_tokens) - start), chunk_content)
                         )
                 else:
                     new_chunks.append((len(_tokens), chunk))

@@ -133,12 +149,12 @@ def chunking_by_token_size(
             )
     else:
         for index, start in enumerate(
-            range(0, len(tokens), max_token_size - overlap_token_size)
+            range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)
         ):
-            chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
+            chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size])
             results.append(
                 {
-                    "tokens": min(max_token_size, len(tokens) - start),
+                    "tokens": min(chunk_token_size, len(tokens) - start),
                     "content": chunk_content.strip(),
                     "chunk_order_index": index,
                 }
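Callers using the old overlap_token_size / max_token_size keywords must migrate to the new names. An illustrative call under the renamed signature, with a toy whitespace tokenizer standing in for the real tokenizer object (assumed to be the first positional parameter, exposing encode/decode):

```python
# Illustrative call; ToyTokenizer is a stand-in, not the library tokenizer.
class ToyTokenizer:
    def encode(self, text: str) -> list[str]:
        return text.split()

    def decode(self, tokens: list[str]) -> str:
        return " ".join(tokens)


chunks = chunking_by_token_size(
    ToyTokenizer(),
    content="one two three four five six seven eight",
    chunk_token_size=5,          # was max_token_size (default 1024 -> 1200)
    chunk_overlap_token_size=2,  # was overlap_token_size (default 128 -> 100)
)
# Each dict carries "tokens", "content" and "chunk_order_index".
```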
@@ -343,6 +359,20 @@ async def _summarize_descriptions(
         llm_response_cache=llm_response_cache,
         cache_type="summary",
     )
+
+    # Check summary token length against embedding limit
+    embedding_token_limit = global_config.get("embedding_token_limit")
+    if embedding_token_limit is not None and summary:
+        tokenizer = global_config["tokenizer"]
+        summary_token_count = len(tokenizer.encode(summary))
+        threshold = int(embedding_token_limit * 0.9)
+        if summary_token_count > threshold:
+            logger.warning(
+                f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
+                f"({embedding_token_limit}) for {description_type}: {description_name}"
+            )
+
     return summary

@@ -367,8 +397,8 @@ async def _handle_single_entity_extraction(
     # Validate entity name after all cleaning steps
     if not entity_name or not entity_name.strip():
-        logger.warning(
-            f"Entity extraction error: entity name became empty after cleaning. Original: '{record_attributes[1]}'"
+        logger.info(
+            f"Empty entity name found after sanitization. Original: '{record_attributes[1]}'"
         )
         return None

@@ -444,14 +474,14 @@ async def _handle_single_relationship_extraction(
     # Validate entity names after all cleaning steps
     if not source:
-        logger.warning(
-            f"Relationship extraction error: source entity became empty after cleaning. Original: '{record_attributes[1]}'"
+        logger.info(
+            f"Empty source entity found after sanitization. Original: '{record_attributes[1]}'"
         )
         return None
     if not target:
-        logger.warning(
-            f"Relationship extraction error: target entity became empty after cleaning. Original: '{record_attributes[2]}'"
+        logger.info(
+            f"Empty target entity found after sanitization. Original: '{record_attributes[2]}'"
         )
         return None

@@ -500,7 +530,7 @@ async def _handle_single_relationship_extraction(
     return None

-async def _rebuild_knowledge_from_chunks(
+async def rebuild_knowledge_from_chunks(
     entities_to_rebuild: dict[str, list[str]],
     relationships_to_rebuild: dict[tuple[str, str], list[str]],
     knowledge_graph_inst: BaseGraphStorage,

@@ -675,14 +705,6 @@ async def _rebuild_knowledge_from_chunks(
                         entity_chunks_storage=entity_chunks_storage,
                     )
                     rebuilt_entities_count += 1
-                    status_message = (
-                        f"Rebuild `{entity_name}` from {len(chunk_ids)} chunks"
-                    )
-                    logger.info(status_message)
-                    if pipeline_status is not None and pipeline_status_lock is not None:
-                        async with pipeline_status_lock:
-                            pipeline_status["latest_message"] = status_message
-                            pipeline_status["history_messages"].append(status_message)
                 except Exception as e:
                     failed_entities_count += 1
                     status_message = f"Failed to rebuild `{entity_name}`: {e}"

@@ -708,6 +730,7 @@ async def _rebuild_knowledge_from_chunks(
                     await _rebuild_single_relationship(
                         knowledge_graph_inst=knowledge_graph_inst,
                         relationships_vdb=relationships_vdb,
+                        entities_vdb=entities_vdb,
                         src=src,
                         tgt=tgt,
                         chunk_ids=chunk_ids,

@@ -715,13 +738,14 @@ async def _rebuild_knowledge_from_chunks(
                         llm_response_cache=llm_response_cache,
                         global_config=global_config,
                         relation_chunks_storage=relation_chunks_storage,
+                        entity_chunks_storage=entity_chunks_storage,
                         pipeline_status=pipeline_status,
                         pipeline_status_lock=pipeline_status_lock,
                     )
                     rebuilt_relationships_count += 1
                 except Exception as e:
                     failed_relationships_count += 1
-                    status_message = f"Failed to rebuild `{src} - {tgt}`: {e}"
+                    status_message = f"Failed to rebuild `{src}`~`{tgt}`: {e}"
                     logger.info(status_message)  # Per requirement, change to info
                     if pipeline_status is not None and pipeline_status_lock is not None:
                         async with pipeline_status_lock:

@@ -1290,6 +1314,7 @@ async def _rebuild_single_entity(
 async def _rebuild_single_relationship(
     knowledge_graph_inst: BaseGraphStorage,
     relationships_vdb: BaseVectorStorage,
+    entities_vdb: BaseVectorStorage,
     src: str,
     tgt: str,
     chunk_ids: list[str],

@@ -1297,6 +1322,7 @@ async def _rebuild_single_relationship(
     llm_response_cache: BaseKVStorage,
     global_config: dict[str, str],
     relation_chunks_storage: BaseKVStorage | None = None,
+    entity_chunks_storage: BaseKVStorage | None = None,
     pipeline_status: dict | None = None,
     pipeline_status_lock=None,
 ) -> None:

@@ -1440,9 +1466,69 @@ async def _rebuild_single_relationship(
             else current_relationship.get("file_path", "unknown_source"),
             "truncate": truncation_info,
         }

+        # Ensure both endpoint nodes exist before writing the edge back
+        # (certain storage backends require pre-existing nodes).
+        node_description = (
+            updated_relationship_data["description"]
+            if updated_relationship_data.get("description")
+            else current_relationship.get("description", "")
+        )
+        node_source_id = updated_relationship_data.get("source_id", "")
+        node_file_path = updated_relationship_data.get("file_path", "unknown_source")
+        for node_id in {src, tgt}:
+            if not (await knowledge_graph_inst.has_node(node_id)):
+                node_created_at = int(time.time())
+                node_data = {
+                    "entity_id": node_id,
+                    "source_id": node_source_id,
+                    "description": node_description,
+                    "entity_type": "UNKNOWN",
+                    "file_path": node_file_path,
+                    "created_at": node_created_at,
+                    "truncate": "",
+                }
+                await knowledge_graph_inst.upsert_node(node_id, node_data=node_data)
+
+                # Update entity_chunks_storage for the newly created entity
+                if entity_chunks_storage is not None and limited_chunk_ids:
+                    await entity_chunks_storage.upsert(
+                        {
+                            node_id: {
+                                "chunk_ids": limited_chunk_ids,
+                                "count": len(limited_chunk_ids),
+                            }
+                        }
+                    )
+
+                # Update entity_vdb for the newly created entity
+                if entities_vdb is not None:
+                    entity_vdb_id = compute_mdhash_id(node_id, prefix="ent-")
+                    entity_content = f"{node_id}\n{node_description}"
+                    vdb_data = {
+                        entity_vdb_id: {
+                            "content": entity_content,
+                            "entity_name": node_id,
+                            "source_id": node_source_id,
+                            "entity_type": "UNKNOWN",
+                            "file_path": node_file_path,
+                        }
+                    }
+                    await safe_vdb_operation_with_exception(
+                        operation=lambda payload=vdb_data: entities_vdb.upsert(payload),
+                        operation_name="rebuild_added_entity_upsert",
+                        entity_name=node_id,
+                        max_retries=3,
+                        retry_delay=0.1,
+                    )
+
         await knowledge_graph_inst.upsert_edge(src, tgt, updated_relationship_data)

         # Update relationship in vector database
+        # Sort src and tgt to ensure consistent ordering (smaller string first)
+        if src > tgt:
+            src, tgt = tgt, src
+
         try:
             rel_vdb_id = compute_mdhash_id(src + tgt, prefix="rel-")
             rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix="rel-")

@@ -1485,7 +1571,7 @@ async def _rebuild_single_relationship(
         raise  # Re-raise exception

     # Log rebuild completion with truncation info
-    status_message = f"Rebuild `{src} - {tgt}` from {len(chunk_ids)} chunks"
+    status_message = f"Rebuild `{src}`~`{tgt}` from {len(chunk_ids)} chunks"
     if truncation_info:
         status_message += f" ({truncation_info})"
     # Add truncation info from apply_source_ids_limit if truncation occurred

@@ -1637,6 +1723,12 @@ async def _merge_nodes_then_upsert(
         logger.error(f"Entity {entity_name} has no description")
         raise ValueError(f"Entity {entity_name} has no description")

+    # Check for cancellation before LLM summary
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException("User cancelled during entity summary")
+
     # 8. Get summary description an LLM usage status
     description, llm_was_used = await _handle_entity_relation_summary(
         "Entity",

@@ -1789,6 +1881,7 @@ async def _merge_edges_then_upsert(
     llm_response_cache: BaseKVStorage | None = None,
     added_entities: list = None,  # New parameter to track entities added during edge processing
     relation_chunks_storage: BaseKVStorage | None = None,
+    entity_chunks_storage: BaseKVStorage | None = None,
 ):
     if src_id == tgt_id:
         return None

@@ -1957,6 +2050,14 @@ async def _merge_edges_then_upsert(
         logger.error(f"Relation {src_id}~{tgt_id} has no description")
         raise ValueError(f"Relation {src_id}~{tgt_id} has no description")

+    # Check for cancellation before LLM summary
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException(
+                    "User cancelled during relation summary"
+                )
+
     # 8. Get summary description an LLM usage status
     description, llm_was_used = await _handle_entity_relation_summary(
         "Relation",

@@ -2065,7 +2166,11 @@ async def _merge_edges_then_upsert(
     # 11. Update both graph and vector db
     for need_insert_id in [src_id, tgt_id]:
-        if not (await knowledge_graph_inst.has_node(need_insert_id)):
+        # Optimization: Use get_node instead of has_node + get_node
+        existing_node = await knowledge_graph_inst.get_node(need_insert_id)
+        if existing_node is None:
+            # Node doesn't exist - create new node
             node_created_at = int(time.time())
             node_data = {
                 "entity_id": need_insert_id,

@@ -2078,6 +2183,19 @@ async def _merge_edges_then_upsert(
             }
             await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)

+            # Update entity_chunks_storage for the newly created entity
+            if entity_chunks_storage is not None:
+                chunk_ids = [chunk_id for chunk_id in full_source_ids if chunk_id]
+                if chunk_ids:
+                    await entity_chunks_storage.upsert(
+                        {
+                            need_insert_id: {
+                                "chunk_ids": chunk_ids,
+                                "count": len(chunk_ids),
+                            }
+                        }
+                    )
+
             if entity_vdb is not None:
                 entity_vdb_id = compute_mdhash_id(need_insert_id, prefix="ent-")
                 entity_content = f"{need_insert_id}\n{description}"

@@ -2109,6 +2227,109 @@ async def _merge_edges_then_upsert(
                     "created_at": node_created_at,
                 }
                 added_entities.append(entity_data)
+        else:
+            # Node exists - update its source_ids by merging with new source_ids
+            updated = False  # Track if any update occurred
+
+            # 1. Get existing full source_ids from entity_chunks_storage
+            existing_full_source_ids = []
+            if entity_chunks_storage is not None:
+                stored_chunks = await entity_chunks_storage.get_by_id(need_insert_id)
+                if stored_chunks and isinstance(stored_chunks, dict):
+                    existing_full_source_ids = [
+                        chunk_id
+                        for chunk_id in stored_chunks.get("chunk_ids", [])
+                        if chunk_id
+                    ]
+
+            # If not in entity_chunks_storage, get from graph database
+            if not existing_full_source_ids:
+                if existing_node.get("source_id"):
+                    existing_full_source_ids = existing_node["source_id"].split(
+                        GRAPH_FIELD_SEP
+                    )
+
+            # 2. Merge with new source_ids from this relationship
+            new_source_ids_from_relation = [
+                chunk_id for chunk_id in source_ids if chunk_id
+            ]
+            merged_full_source_ids = merge_source_ids(
+                existing_full_source_ids, new_source_ids_from_relation
+            )
+
+            # 3. Save merged full list to entity_chunks_storage (conditional)
+            if (
+                entity_chunks_storage is not None
+                and merged_full_source_ids != existing_full_source_ids
+            ):
+                updated = True
+                await entity_chunks_storage.upsert(
+                    {
+                        need_insert_id: {
+                            "chunk_ids": merged_full_source_ids,
+                            "count": len(merged_full_source_ids),
+                        }
+                    }
+                )

+            # 4. Apply source_ids limit for graph and vector db
+            limit_method = global_config.get(
+                "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP
+            )
+            max_source_limit = global_config.get("max_source_ids_per_entity")
+            limited_source_ids = apply_source_ids_limit(
+                merged_full_source_ids,
+                max_source_limit,
+                limit_method,
+                identifier=f"`{need_insert_id}`",
+            )
+
+            # 5. Update graph database and vector database with limited source_ids (conditional)
+            limited_source_id_str = GRAPH_FIELD_SEP.join(limited_source_ids)
+            if limited_source_id_str != existing_node.get("source_id", ""):
+                updated = True
+                updated_node_data = {
+                    **existing_node,
+                    "source_id": limited_source_id_str,
+                }
+                await knowledge_graph_inst.upsert_node(
+                    need_insert_id, node_data=updated_node_data
+                )
+
+                # Update vector database
+                if entity_vdb is not None:
+                    entity_vdb_id = compute_mdhash_id(need_insert_id, prefix="ent-")
+                    entity_content = (
+                        f"{need_insert_id}\n{existing_node.get('description', '')}"
+                    )
+                    vdb_data = {
+                        entity_vdb_id: {
+                            "content": entity_content,
+                            "entity_name": need_insert_id,
+                            "source_id": limited_source_id_str,
+                            "entity_type": existing_node.get("entity_type", "UNKNOWN"),
+                            "file_path": existing_node.get(
+                                "file_path", "unknown_source"
+                            ),
+                        }
+                    }
+                    await safe_vdb_operation_with_exception(
+                        operation=lambda payload=vdb_data: entity_vdb.upsert(payload),
+                        operation_name="existing_entity_update",
+                        entity_name=need_insert_id,
+                        max_retries=3,
+                        retry_delay=0.1,
+                    )
+
+            # 6. Log once at the end if any update occurred
+            if updated:
+                status_message = f"Chunks appended from relation: `{need_insert_id}`"
+                logger.info(status_message)
+                if pipeline_status is not None and pipeline_status_lock is not None:
+                    async with pipeline_status_lock:
+                        pipeline_status["latest_message"] = status_message
+                        pipeline_status["history_messages"].append(status_message)

     edge_created_at = int(time.time())
     await knowledge_graph_inst.upsert_edge(

@@ -2137,6 +2358,10 @@ async def _merge_edges_then_upsert(
         weight=weight,
     )

+    # Sort src_id and tgt_id to ensure consistent ordering (smaller string first)
+    if src_id > tgt_id:
+        src_id, tgt_id = tgt_id, src_id
+
     if relationships_vdb is not None:
         rel_vdb_id = compute_mdhash_id(src_id + tgt_id, prefix="rel-")
         rel_vdb_id_reverse = compute_mdhash_id(tgt_id + src_id, prefix="rel-")
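Normalizing the endpoint order before hashing makes the relationship's vector-db id direction-independent, so (A, B) and (B, A) resolve to the same record. A sketch of the invariant, with md5 standing in for compute_mdhash_id (assumed to be a deterministic content hash):

```python
from hashlib import md5


def rel_id(src: str, tgt: str) -> str:
    if src > tgt:  # smaller string first, as in the diff above
        src, tgt = tgt, src
    return "rel-" + md5((src + tgt).encode()).hexdigest()


assert rel_id("Bob", "Alice") == rel_id("Alice", "Bob")
```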
@@ -2214,6 +2439,12 @@ async def merge_nodes_and_edges(
         file_path: File path for logging
     """
+    # Check for cancellation at the start of merge
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException("User cancelled during merge phase")
+
     # Collect all nodes and edges from all chunks
     all_nodes = defaultdict(list)
     all_edges = defaultdict(list)
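The same guarded flag check is inlined before every expensive phase below (entity merge, relation merge, extraction start, per-chunk processing). Factored out as a hypothetical helper, the pattern is:

```python
# Hypothetical helper equivalent to the inlined checks in this commit.
async def raise_if_cancelled(pipeline_status, pipeline_status_lock, phase: str):
    if pipeline_status is not None and pipeline_status_lock is not None:
        async with pipeline_status_lock:
            if pipeline_status.get("cancellation_requested", False):
                raise PipelineCancelledException(f"User cancelled during {phase}")
```

Holding the lock only for the flag read keeps the critical section minimal while still giving each task a consistent view of the shared status dict.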
@@ -2250,6 +2481,14 @@ async def merge_nodes_and_edges(
     async def _locked_process_entity_name(entity_name, entities):
         async with semaphore:
+            # Check for cancellation before processing entity
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    if pipeline_status.get("cancellation_requested", False):
+                        raise PipelineCancelledException(
+                            "User cancelled during entity merge"
+                        )
+
             workspace = global_config.get("workspace", "")
             namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
             async with get_storage_keyed_lock(

@@ -2343,6 +2582,14 @@ async def merge_nodes_and_edges(
     async def _locked_process_edges(edge_key, edges):
         async with semaphore:
+            # Check for cancellation before processing edges
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    if pipeline_status.get("cancellation_requested", False):
+                        raise PipelineCancelledException(
+                            "User cancelled during relation merge"
+                        )
+
             workspace = global_config.get("workspace", "")
             namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
             sorted_edge_key = sorted([edge_key[0], edge_key[1]])

@@ -2369,6 +2616,7 @@ async def merge_nodes_and_edges(
                     llm_response_cache,
                     added_entities,  # Pass list to collect added entities
                     relation_chunks_storage,
+                    entity_chunks_storage,  # Add entity_chunks_storage parameter
                 )

                 if edge_data is None:

@@ -2525,6 +2773,14 @@ async def extract_entities(
     llm_response_cache: BaseKVStorage | None = None,
     text_chunks_storage: BaseKVStorage | None = None,
 ) -> list:
+    # Check for cancellation at the start of entity extraction
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException(
+                    "User cancelled during entity extraction"
+                )
+
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

@@ -2692,6 +2948,14 @@ async def extract_entities(
     async def _process_with_semaphore(chunk):
         async with semaphore:
+            # Check for cancellation before processing chunk
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    if pipeline_status.get("cancellation_requested", False):
+                        raise PipelineCancelledException(
+                            "User cancelled during chunk processing"
+                        )
+
             try:
                 return await _process_single_content(chunk)
             except Exception as e:

@@ -3189,10 +3453,10 @@ async def _perform_kg_search(
     )

     query_embedding = None
     if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
-        embedding_func_config = text_chunks_db.embedding_func
-        if embedding_func_config and embedding_func_config.func:
+        actual_embedding_func = text_chunks_db.embedding_func
+        if actual_embedding_func:
             try:
-                query_embedding = await embedding_func_config.func([query])
+                query_embedding = await actual_embedding_func([query])
                 query_embedding = query_embedding[
                     0
                 ]  # Extract first embedding from batch result

@@ -3596,7 +3860,7 @@ async def _merge_all_chunks(
     return merged_chunks

-async def _build_llm_context(
+async def _build_context_str(
     entities_context: list[dict],
     relations_context: list[dict],
     merged_chunks: list[dict],

@@ -3696,23 +3960,32 @@ async def _build_llm_context(
             truncated_chunks
         )

-        # Rebuild text_units_context with truncated chunks
+        # Rebuild chunks_context with truncated chunks
         # The actual tokens may be slightly less than available_chunk_tokens due to deduplication logic
-        text_units_context = []
+        chunks_context = []
         for i, chunk in enumerate(truncated_chunks):
-            text_units_context.append(
+            chunks_context.append(
                 {
                     "reference_id": chunk["reference_id"],
                     "content": chunk["content"],
                 }
             )

+    text_units_str = "\n".join(
+        json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context
+    )
+    reference_list_str = "\n".join(
+        f"[{ref['reference_id']}] {ref['file_path']}"
+        for ref in reference_list
+        if ref["reference_id"]
+    )
+
     logger.info(
-        f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
+        f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(chunks_context)} chunks"
     )

     # not necessary to use LLM to generate a response
-    if not entities_context and not relations_context:
+    if not entities_context and not relations_context and not chunks_context:
         # Return empty raw data structure when no entities/relations
         empty_raw_data = convert_to_user_format(
             [],

@@ -3743,15 +4016,6 @@ async def _build_llm_context(
     if chunk_tracking_log:
         logger.info(f"Final chunks S+F/O: {' '.join(chunk_tracking_log)}")

-    text_units_str = "\n".join(
-        json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
-    )
-    reference_list_str = "\n".join(
-        f"[{ref['reference_id']}] {ref['file_path']}"
-        for ref in reference_list
-        if ref["reference_id"]
-    )
-
     result = kg_context_template.format(
         entities_str=entities_str,
         relations_str=relations_str,

@@ -3761,7 +4025,7 @@ async def _build_llm_context(
     # Always return both context and complete data structure (unified approach)
     logger.debug(
-        f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
+        f"[_build_context_str] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
     )
     final_data = convert_to_user_format(
         entities_context,

@@ -3773,7 +4037,7 @@ async def _build_llm_context(
         relation_id_to_original,
     )
     logger.debug(
-        f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
+        f"[_build_context_str] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
     )
     return result, final_data

@@ -3850,8 +4114,8 @@ async def _build_query_context(
         return None

     # Stage 4: Build final LLM context with dynamic token processing
-    # _build_llm_context now always returns tuple[str, dict]
-    context, raw_data = await _build_llm_context(
+    # _build_context_str now always returns tuple[str, dict]
+    context, raw_data = await _build_context_str(
         entities_context=truncation_result["entities_context"],
         relations_context=truncation_result["relations_context"],
         merged_chunks=merged_chunks,

@@ -4100,25 +4364,21 @@ async def _find_related_text_unit_from_entities(
     num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)

     # Get embedding function from global config
-    embedding_func_config = text_chunks_db.embedding_func
-    if not embedding_func_config:
+    actual_embedding_func = text_chunks_db.embedding_func
+    if not actual_embedding_func:
         logger.warning("No embedding function found, falling back to WEIGHT method")
         kg_chunk_pick_method = "WEIGHT"
     else:
         try:
-            actual_embedding_func = embedding_func_config.func
-
-            selected_chunk_ids = None
-            if actual_embedding_func:
-                selected_chunk_ids = await pick_by_vector_similarity(
-                    query=query,
-                    text_chunks_storage=text_chunks_db,
-                    chunks_vdb=chunks_vdb,
-                    num_of_chunks=num_of_chunks,
-                    entity_info=entities_with_chunks,
-                    embedding_func=actual_embedding_func,
-                    query_embedding=query_embedding,
-                )
+            selected_chunk_ids = await pick_by_vector_similarity(
+                query=query,
+                text_chunks_storage=text_chunks_db,
+                chunks_vdb=chunks_vdb,
+                num_of_chunks=num_of_chunks,
+                entity_info=entities_with_chunks,
+                embedding_func=actual_embedding_func,
+                query_embedding=query_embedding,
+            )

             if selected_chunk_ids == []:
                 kg_chunk_pick_method = "WEIGHT"

@@ -4393,24 +4653,21 @@ async def _find_related_text_unit_from_relations(
     num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)

     # Get embedding function from global config
-    embedding_func_config = text_chunks_db.embedding_func
-    if not embedding_func_config:
+    actual_embedding_func = text_chunks_db.embedding_func
+    if not actual_embedding_func:
         logger.warning("No embedding function found, falling back to WEIGHT method")
         kg_chunk_pick_method = "WEIGHT"
     else:
         try:
-            actual_embedding_func = embedding_func_config.func
-
-            if actual_embedding_func:
-                selected_chunk_ids = await pick_by_vector_similarity(
-                    query=query,
-                    text_chunks_storage=text_chunks_db,
-                    chunks_vdb=chunks_vdb,
-                    num_of_chunks=num_of_chunks,
-                    entity_info=relations_with_chunks,
-                    embedding_func=actual_embedding_func,
-                    query_embedding=query_embedding,
-                )
+            selected_chunk_ids = await pick_by_vector_similarity(
+                query=query,
+                text_chunks_storage=text_chunks_db,
+                chunks_vdb=chunks_vdb,
+                num_of_chunks=num_of_chunks,
+                entity_info=relations_with_chunks,
+                embedding_func=actual_embedding_func,
+                query_embedding=query_embedding,
+            )

             if selected_chunk_ids == []:
                 kg_chunk_pick_method = "WEIGHT"

@@ -4624,10 +4881,10 @@ async def naive_query(
         "final_chunks_count": len(processed_chunks_with_ref_ids),
     }

-    # Build text_units_context from processed chunks with reference IDs
-    text_units_context = []
+    # Build chunks_context from processed chunks with reference IDs
+    chunks_context = []
     for i, chunk in enumerate(processed_chunks_with_ref_ids):
-        text_units_context.append(
+        chunks_context.append(
             {
                 "reference_id": chunk["reference_id"],
                 "content": chunk["content"],

@@ -4635,7 +4892,7 @@ async def naive_query(
         )

     text_units_str = "\n".join(
-        json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
+        json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context
     )
     reference_list_str = "\n".join(
         f"[{ref['reference_id']}] {ref['file_path']}"