fix: sync core modules from upstream

This commit is contained in:
Raphaël MANSUY 2025-12-04 19:20:28 +08:00
parent f6f3ed93d3
commit d1262e999d
2 changed files with 287 additions and 99 deletions

View file

@ -8,6 +8,10 @@ import json_repair
from typing import Any, AsyncIterator, overload, Literal
from collections import Counter, defaultdict
from lightrag.exceptions import (
PipelineCancelledException,
ChunkTokenLimitExceededError,
)
from lightrag.utils import (
logger,
compute_mdhash_id,
@ -82,11 +86,11 @@ def _truncate_entity_identifier(
display_value = identifier[:limit]
preview = identifier[:20] # Show first 20 characters as preview
logger.warning(
"%s: %s exceeded %d characters (len: %d, preview: '%s...'",
"%s: %s len %d > %d chars (Name: '%s...')",
chunk_key,
identifier_role,
limit,
len(identifier),
limit,
preview,
)
return display_value
@ -97,8 +101,8 @@ def chunking_by_token_size(
content: str,
split_by_character: str | None = None,
split_by_character_only: bool = False,
overlap_token_size: int = 128,
max_token_size: int = 1024,
chunk_overlap_token_size: int = 100,
chunk_token_size: int = 1200,
) -> list[dict[str, Any]]:
tokens = tokenizer.encode(content)
results: list[dict[str, Any]] = []
@ -108,19 +112,30 @@ def chunking_by_token_size(
if split_by_character_only:
for chunk in raw_chunks:
_tokens = tokenizer.encode(chunk)
if len(_tokens) > chunk_token_size:
logger.warning(
"Chunk split_by_character exceeds token limit: len=%d limit=%d",
len(_tokens),
chunk_token_size,
)
raise ChunkTokenLimitExceededError(
chunk_tokens=len(_tokens),
chunk_token_limit=chunk_token_size,
chunk_preview=chunk[:120],
)
new_chunks.append((len(_tokens), chunk))
else:
for chunk in raw_chunks:
_tokens = tokenizer.encode(chunk)
if len(_tokens) > max_token_size:
if len(_tokens) > chunk_token_size:
for start in range(
0, len(_tokens), max_token_size - overlap_token_size
0, len(_tokens), chunk_token_size - chunk_overlap_token_size
):
chunk_content = tokenizer.decode(
_tokens[start : start + max_token_size]
_tokens[start : start + chunk_token_size]
)
new_chunks.append(
(min(max_token_size, len(_tokens) - start), chunk_content)
(min(chunk_token_size, len(_tokens) - start), chunk_content)
)
else:
new_chunks.append((len(_tokens), chunk))
@ -134,12 +149,12 @@ def chunking_by_token_size(
)
else:
for index, start in enumerate(
range(0, len(tokens), max_token_size - overlap_token_size)
range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)
):
chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size])
results.append(
{
"tokens": min(max_token_size, len(tokens) - start),
"tokens": min(chunk_token_size, len(tokens) - start),
"content": chunk_content.strip(),
"chunk_order_index": index,
}
@ -344,6 +359,20 @@ async def _summarize_descriptions(
llm_response_cache=llm_response_cache,
cache_type="summary",
)
# Check summary token length against embedding limit
embedding_token_limit = global_config.get("embedding_token_limit")
if embedding_token_limit is not None and summary:
tokenizer = global_config["tokenizer"]
summary_token_count = len(tokenizer.encode(summary))
threshold = int(embedding_token_limit * 0.9)
if summary_token_count > threshold:
logger.warning(
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
f"({embedding_token_limit}) for {description_type}: {description_name}"
)
return summary
@ -988,13 +1017,13 @@ async def _process_extraction_result(
relationship_data["src_id"],
DEFAULT_ENTITY_NAME_MAX_LENGTH,
chunk_key,
"Relationship source entity",
"Relation entity",
)
truncated_target = _truncate_entity_identifier(
relationship_data["tgt_id"],
DEFAULT_ENTITY_NAME_MAX_LENGTH,
chunk_key,
"Relationship target entity",
"Relation entity",
)
relationship_data["src_id"] = truncated_source
relationship_data["tgt_id"] = truncated_target
@ -1694,6 +1723,12 @@ async def _merge_nodes_then_upsert(
logger.error(f"Entity {entity_name} has no description")
raise ValueError(f"Entity {entity_name} has no description")
# Check for cancellation before LLM summary
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException("User cancelled during entity summary")
# 8. Get summary description and LLM usage status
description, llm_was_used = await _handle_entity_relation_summary(
"Entity",
@ -2015,6 +2050,14 @@ async def _merge_edges_then_upsert(
logger.error(f"Relation {src_id}~{tgt_id} has no description")
raise ValueError(f"Relation {src_id}~{tgt_id} has no description")
# Check for cancellation before LLM summary
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during relation summary"
)
# 8. Get summary description and LLM usage status
description, llm_was_used = await _handle_entity_relation_summary(
"Relation",
@ -2396,6 +2439,12 @@ async def merge_nodes_and_edges(
file_path: File path for logging
"""
# Check for cancellation at the start of merge
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException("User cancelled during merge phase")
# Collect all nodes and edges from all chunks
all_nodes = defaultdict(list)
all_edges = defaultdict(list)
@ -2432,6 +2481,14 @@ async def merge_nodes_and_edges(
async def _locked_process_entity_name(entity_name, entities):
async with semaphore:
# Check for cancellation before processing entity
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during entity merge"
)
workspace = global_config.get("workspace", "")
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
async with get_storage_keyed_lock(
@ -2454,9 +2511,7 @@ async def merge_nodes_and_edges(
return entity_data
except Exception as e:
error_msg = (
f"Critical error in entity processing for `{entity_name}`: {e}"
)
error_msg = f"Error processing entity `{entity_name}`: {e}"
logger.error(error_msg)
# Try to update pipeline status, but don't let status update failure affect main exception
@ -2492,36 +2547,32 @@ async def merge_nodes_and_edges(
entity_tasks, return_when=asyncio.FIRST_EXCEPTION
)
# Check if any task raised an exception and ensure all exceptions are retrieved
first_exception = None
successful_results = []
processed_entities = []
for task in done:
try:
exception = task.exception()
if exception is not None:
if first_exception is None:
first_exception = exception
else:
successful_results.append(task.result())
except Exception as e:
result = task.result()
except BaseException as e:
if first_exception is None:
first_exception = e
else:
processed_entities.append(result)
if pending:
for task in pending:
task.cancel()
pending_results = await asyncio.gather(*pending, return_exceptions=True)
for result in pending_results:
if isinstance(result, BaseException):
if first_exception is None:
first_exception = result
else:
processed_entities.append(result)
# If any task failed, cancel all pending tasks and raise the first exception
if first_exception is not None:
# Cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the first exception to notify the caller
raise first_exception
# If all tasks completed successfully, collect results
processed_entities = [task.result() for task in entity_tasks]
# ===== Phase 2: Process all relationships concurrently =====
log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
logger.info(log_message)
@ -2531,6 +2582,14 @@ async def merge_nodes_and_edges(
async def _locked_process_edges(edge_key, edges):
async with semaphore:
# Check for cancellation before processing edges
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during relation merge"
)
workspace = global_config.get("workspace", "")
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
@ -2566,7 +2625,7 @@ async def merge_nodes_and_edges(
return edge_data, added_entities
except Exception as e:
error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}"
error_msg = f"Error processing relation `{sorted_edge_key}`: {e}"
logger.error(error_msg)
# Try to update pipeline status, but don't let status update failure affect main exception
@ -2604,40 +2663,36 @@ async def merge_nodes_and_edges(
edge_tasks, return_when=asyncio.FIRST_EXCEPTION
)
# Check if any task raised an exception and ensure all exceptions are retrieved
first_exception = None
successful_results = []
for task in done:
try:
exception = task.exception()
if exception is not None:
if first_exception is None:
first_exception = exception
else:
successful_results.append(task.result())
except Exception as e:
edge_data, added_entities = task.result()
except BaseException as e:
if first_exception is None:
first_exception = e
else:
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
if pending:
for task in pending:
task.cancel()
pending_results = await asyncio.gather(*pending, return_exceptions=True)
for result in pending_results:
if isinstance(result, BaseException):
if first_exception is None:
first_exception = result
else:
edge_data, added_entities = result
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
# If any task failed, cancel all pending tasks and raise the first exception
if first_exception is not None:
# Cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the first exception to notify the caller
raise first_exception
# If all tasks completed successfully, collect results
for task in edge_tasks:
edge_data, added_entities = task.result()
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
# ===== Phase 3: Update full_entities and full_relations storage =====
if full_entities_storage and full_relations_storage and doc_id:
try:
@ -2718,6 +2773,14 @@ async def extract_entities(
llm_response_cache: BaseKVStorage | None = None,
text_chunks_storage: BaseKVStorage | None = None,
) -> list:
# Check for cancellation at the start of entity extraction
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during entity extraction"
)
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@ -2885,6 +2948,14 @@ async def extract_entities(
async def _process_with_semaphore(chunk):
async with semaphore:
# Check for cancellation before processing chunk
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
if pipeline_status.get("cancellation_requested", False):
raise PipelineCancelledException(
"User cancelled during chunk processing"
)
try:
return await _process_single_content(chunk)
except Exception as e:
@ -3382,10 +3453,10 @@ async def _perform_kg_search(
)
query_embedding = None
if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
embedding_func_config = text_chunks_db.embedding_func
if embedding_func_config and embedding_func_config.func:
actual_embedding_func = text_chunks_db.embedding_func
if actual_embedding_func:
try:
query_embedding = await embedding_func_config.func([query])
query_embedding = await actual_embedding_func([query])
query_embedding = query_embedding[
0
] # Extract first embedding from batch result
@ -4293,25 +4364,21 @@ async def _find_related_text_unit_from_entities(
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
# Get embedding function from global config
embedding_func_config = text_chunks_db.embedding_func
if not embedding_func_config:
actual_embedding_func = text_chunks_db.embedding_func
if not actual_embedding_func:
logger.warning("No embedding function found, falling back to WEIGHT method")
kg_chunk_pick_method = "WEIGHT"
else:
try:
actual_embedding_func = embedding_func_config.func
selected_chunk_ids = None
if actual_embedding_func:
selected_chunk_ids = await pick_by_vector_similarity(
query=query,
text_chunks_storage=text_chunks_db,
chunks_vdb=chunks_vdb,
num_of_chunks=num_of_chunks,
entity_info=entities_with_chunks,
embedding_func=actual_embedding_func,
query_embedding=query_embedding,
)
selected_chunk_ids = await pick_by_vector_similarity(
query=query,
text_chunks_storage=text_chunks_db,
chunks_vdb=chunks_vdb,
num_of_chunks=num_of_chunks,
entity_info=entities_with_chunks,
embedding_func=actual_embedding_func,
query_embedding=query_embedding,
)
if selected_chunk_ids == []:
kg_chunk_pick_method = "WEIGHT"
@ -4586,24 +4653,21 @@ async def _find_related_text_unit_from_relations(
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
# Get embedding function from global config
embedding_func_config = text_chunks_db.embedding_func
if not embedding_func_config:
actual_embedding_func = text_chunks_db.embedding_func
if not actual_embedding_func:
logger.warning("No embedding function found, falling back to WEIGHT method")
kg_chunk_pick_method = "WEIGHT"
else:
try:
actual_embedding_func = embedding_func_config.func
if actual_embedding_func:
selected_chunk_ids = await pick_by_vector_similarity(
query=query,
text_chunks_storage=text_chunks_db,
chunks_vdb=chunks_vdb,
num_of_chunks=num_of_chunks,
entity_info=relations_with_chunks,
embedding_func=actual_embedding_func,
query_embedding=query_embedding,
)
selected_chunk_ids = await pick_by_vector_similarity(
query=query,
text_chunks_storage=text_chunks_db,
chunks_vdb=chunks_vdb,
num_of_chunks=num_of_chunks,
entity_info=relations_with_chunks,
embedding_func=actual_embedding_func,
query_embedding=query_embedding,
)
if selected_chunk_ids == []:
kg_chunk_pick_method = "WEIGHT"

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import weakref
import sys
import asyncio
import html
import csv
@ -40,6 +42,35 @@ from lightrag.constants import (
SOURCE_IDS_LIMIT_METHOD_FIFO,
)
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
class SafeStreamHandler(logging.StreamHandler):
    """StreamHandler that tolerates streams closed during interpreter shutdown.

    Test frameworks such as pytest may close stdout/stderr before Python's
    logging cleanup runs; the stock StreamHandler then raises
    "ValueError: I/O operation on closed file" from flush()/close().
    This subclass silently swallows those errors instead.
    """

    @staticmethod
    def _run_quietly(operation):
        # Closed-stream failures surface as ValueError or OSError; both are
        # expected during shutdown and deliberately ignored.
        try:
            operation()
        except (OSError, ValueError):
            pass

    def flush(self):
        """Flush the underlying stream, ignoring closed-stream errors."""
        self._run_quietly(super().flush)

    def close(self):
        """Close the handler, ignoring closed-stream errors."""
        self._run_quietly(super().close)
# Initialize logger with basic configuration
logger = logging.getLogger("lightrag")
logger.propagate = False # prevent log message send to root logger
@ -47,7 +78,7 @@ logger.setLevel(logging.INFO)
# Add console handler if no handlers exist
if not logger.handlers:
console_handler = logging.StreamHandler()
console_handler = SafeStreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(levelname)s: %(message)s")
console_handler.setFormatter(formatter)
@ -56,8 +87,32 @@ if not logger.handlers:
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
def _patch_ascii_colors_console_handler() -> None:
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
try:
from ascii_colors import ConsoleHandler
except ImportError:
return
if getattr(ConsoleHandler, "_lightrag_patched", False):
return
original_handle_error = ConsoleHandler.handle_error
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
exc_type, _, _ = sys.exc_info()
if exc_type in (ValueError, OSError) and "close" in message.lower():
return
original_handle_error(self, message)
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
_patch_ascii_colors_console_handler()
# Global import for pypinyin with startup-time logging
try:
@ -286,8 +341,8 @@ def setup_logger(
logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False
# Add console handler
console_handler = logging.StreamHandler()
# Add console handler with safe stream handling
console_handler = SafeStreamHandler()
console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
@ -950,7 +1005,76 @@ def priority_limit_async_func_call(
def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""
"""Decorator to add embedding dimension and token limit attributes to embedding functions.
This decorator wraps an async embedding function and returns an EmbeddingFunc instance
that automatically handles dimension parameter injection and attribute management.
WARNING: DO NOT apply this decorator to wrapper functions that call other
decorated embedding functions. This will cause double decoration and parameter
injection conflicts.
Correct usage patterns:
1. Direct implementation (decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_embed(texts, embedding_dim=None):
# Direct implementation
return embeddings
```
2. Wrapper calling decorated function (DO NOT decorate wrapper):
```python
# my_embed is already decorated above
async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
# Must call .func to access unwrapped implementation
return await my_embed.func(texts, **kwargs)
```
3. Wrapper calling decorated function (properly decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
# Calling .func avoids double decoration
return await my_embed.func(texts, **kwargs)
```
The decorated function becomes an EmbeddingFunc instance with:
- embedding_dim: The embedding dimension
- max_token_size: Maximum token limit (optional)
- func: The original unwrapped function (access via .func)
- __call__: Wrapper that injects embedding_dim parameter
Double decoration causes:
- Double injection of embedding_dim parameter
- Incorrect parameter passing to the underlying implementation
- Runtime errors due to parameter conflicts
Args:
embedding_dim: The dimension of embedding vectors
max_token_size: Maximum number of tokens (optional)
send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
Returns:
A decorator that wraps the function as an EmbeddingFunc instance
Example of correct wrapper implementation:
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(...)
async def openai_embed(texts, ...):
# Base implementation
pass
@wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
async def azure_openai_embed(texts, ...):
# CRITICAL: Call .func to access unwrapped function
return await openai_embed.func(texts, ...) # ✅ Correct
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
```
"""
def final_decro(func) -> EmbeddingFunc:
new_func = EmbeddingFunc(**kwargs, func=func)