From d1262e999d22cffa5693527a1a15160408e52568 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?=
Date: Thu, 4 Dec 2025 19:20:28 +0800
Subject: [PATCH] fix: sync core modules from upstream

---
 lightrag/operate.py | 250 ++++++++++++++++++++++++++++----------------
 lightrag/utils.py   | 136 ++++++++++++++++++++++--
 2 files changed, 287 insertions(+), 99 deletions(-)

diff --git a/lightrag/operate.py b/lightrag/operate.py
index 6185dbeb..c6724974 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -8,6 +8,10 @@ import json_repair
 from typing import Any, AsyncIterator, overload, Literal
 from collections import Counter, defaultdict
 
+from lightrag.exceptions import (
+    PipelineCancelledException,
+    ChunkTokenLimitExceededError,
+)
 from lightrag.utils import (
     logger,
     compute_mdhash_id,
@@ -82,11 +86,11 @@ def _truncate_entity_identifier(
         display_value = identifier[:limit]
         preview = identifier[:20]  # Show first 20 characters as preview
         logger.warning(
-            "%s: %s exceeded %d characters (len: %d, preview: '%s...'",
+            "%s: %s len %d > %d chars (Name: '%s...')",
             chunk_key,
             identifier_role,
-            limit,
             len(identifier),
+            limit,
             preview,
         )
         return display_value
@@ -97,8 +101,8 @@ def chunking_by_token_size(
     content: str,
     split_by_character: str | None = None,
     split_by_character_only: bool = False,
-    overlap_token_size: int = 128,
-    max_token_size: int = 1024,
+    chunk_overlap_token_size: int = 100,
+    chunk_token_size: int = 1200,
 ) -> list[dict[str, Any]]:
     tokens = tokenizer.encode(content)
     results: list[dict[str, Any]] = []
@@ -108,19 +112,30 @@ def chunking_by_token_size(
         if split_by_character_only:
             for chunk in raw_chunks:
                 _tokens = tokenizer.encode(chunk)
+                if len(_tokens) > chunk_token_size:
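+                    # Character-only splitting cannot re-split an oversized
+                    # chunk, so exceeding the limit is a hard error here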
+                    logger.warning(
+                        "Chunk split_by_character exceeds token limit: len=%d limit=%d",
+                        len(_tokens),
+                        chunk_token_size,
+                    )
+                    raise ChunkTokenLimitExceededError(
+                        chunk_tokens=len(_tokens),
+                        chunk_token_limit=chunk_token_size,
+                        chunk_preview=chunk[:120],
+                    )
                 new_chunks.append((len(_tokens), chunk))
         else:
             for chunk in raw_chunks:
                 _tokens = tokenizer.encode(chunk)
-                if len(_tokens) > max_token_size:
+                if len(_tokens) > chunk_token_size:
                     for start in range(
-                        0, len(_tokens), max_token_size - overlap_token_size
+                        0, len(_tokens), chunk_token_size - chunk_overlap_token_size
                     ):
                         chunk_content = tokenizer.decode(
-                            _tokens[start : start + max_token_size]
+                            _tokens[start : start + chunk_token_size]
                         )
                         new_chunks.append(
-                            (min(max_token_size, len(_tokens) - start), chunk_content)
+                            (min(chunk_token_size, len(_tokens) - start), chunk_content)
                         )
                 else:
                     new_chunks.append((len(_tokens), chunk))
@@ -134,12 +149,12 @@ def chunking_by_token_size(
         )
     else:
         for index, start in enumerate(
-            range(0, len(tokens), max_token_size - overlap_token_size)
+            range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)
         ):
-            chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
+            chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size])
             results.append(
                 {
-                    "tokens": min(max_token_size, len(tokens) - start),
+                    "tokens": min(chunk_token_size, len(tokens) - start),
                     "content": chunk_content.strip(),
                     "chunk_order_index": index,
                 }
             )
@@ -344,6 +359,20 @@ async def _summarize_descriptions(
         llm_response_cache=llm_response_cache,
         cache_type="summary",
     )
+
+    # Check summary token length against embedding limit
+    embedding_token_limit = global_config.get("embedding_token_limit")
+    if embedding_token_limit is not None and summary:
+        tokenizer = global_config["tokenizer"]
+        summary_token_count = len(tokenizer.encode(summary))
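+        # Warn at 90% of the limit to leave headroom before the hard cap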
+        threshold = int(embedding_token_limit * 0.9)
+
+        if summary_token_count > threshold:
+            logger.warning(
+                f"Summary tokens ({summary_token_count}) exceed 90% of embedding limit "
+                f"({embedding_token_limit}) for {description_type}: {description_name}"
+            )
+
     return summary
 
 
@@ -988,13 +1017,13 @@ async def _process_extraction_result(
             relationship_data["src_id"],
             DEFAULT_ENTITY_NAME_MAX_LENGTH,
             chunk_key,
-            "Relationship source entity",
+            "Relation entity",
         )
         truncated_target = _truncate_entity_identifier(
             relationship_data["tgt_id"],
             DEFAULT_ENTITY_NAME_MAX_LENGTH,
             chunk_key,
-            "Relationship target entity",
+            "Relation entity",
         )
         relationship_data["src_id"] = truncated_source
         relationship_data["tgt_id"] = truncated_target
@@ -1694,6 +1723,12 @@ async def _merge_nodes_then_upsert(
         logger.error(f"Entity {entity_name} has no description")
         raise ValueError(f"Entity {entity_name} has no description")
 
+    # Check for cancellation before LLM summary
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException("User cancelled during entity summary")
+
     # 8. Get summary description an LLM usage status
     description, llm_was_used = await _handle_entity_relation_summary(
         "Entity",
@@ -2015,6 +2050,14 @@ async def _merge_edges_then_upsert(
         logger.error(f"Relation {src_id}~{tgt_id} has no description")
         raise ValueError(f"Relation {src_id}~{tgt_id} has no description")
 
+    # Check for cancellation before LLM summary
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException(
+                    "User cancelled during relation summary"
+                )
+
     # 8. Get summary description an LLM usage status
     description, llm_was_used = await _handle_entity_relation_summary(
         "Relation",
@@ -2396,6 +2439,12 @@ async def merge_nodes_and_edges(
         file_path: File path for logging
     """
 
+    # Check for cancellation at the start of merge
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException("User cancelled during merge phase")
+
     # Collect all nodes and edges from all chunks
     all_nodes = defaultdict(list)
     all_edges = defaultdict(list)
@@ -2432,6 +2481,14 @@
 
     async def _locked_process_entity_name(entity_name, entities):
         async with semaphore:
+            # Check for cancellation before processing entity
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    if pipeline_status.get("cancellation_requested", False):
+                        raise PipelineCancelledException(
+                            "User cancelled during entity merge"
+                        )
+
             workspace = global_config.get("workspace", "")
             namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
             async with get_storage_keyed_lock(
@@ -2454,9 +2511,7 @@
 
                 return entity_data
             except Exception as e:
-                error_msg = (
-                    f"Critical error in entity processing for `{entity_name}`: {e}"
-                )
+                error_msg = f"Error processing entity `{entity_name}`: {e}"
                 logger.error(error_msg)
 
                 # Try to update pipeline status, but don't let status update failure affect main exception
@@ -2492,36 +2547,32 @@
             entity_tasks, return_when=asyncio.FIRST_EXCEPTION
         )
 
-        # Check if any task raised an exception and ensure all exceptions are retrieved
         first_exception = None
-        successful_results = []
+        processed_entities = []
        for task in done:
             try:
-                exception = task.exception()
-                if exception is not None:
-                    if first_exception is None:
-                        first_exception = exception
-                else:
-                    successful_results.append(task.result())
-            except Exception as e:
+                result = task.result()
+            except BaseException as e:
                 if first_exception is None:
                     first_exception = e
+            else:
+                processed_entities.append(result)
+
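+        # Cancel tasks that are still pending and drain their outcomes so
+        # no exception from a cancelled task is left unretrieved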
+        if pending:
+            for task in pending:
+                task.cancel()
+            pending_results = await asyncio.gather(*pending, return_exceptions=True)
+            for result in pending_results:
+                if isinstance(result, BaseException):
+                    if first_exception is None:
+                        first_exception = result
+                else:
+                    processed_entities.append(result)
 
-        # If any task failed, cancel all pending tasks and raise the first exception
         if first_exception is not None:
-            # Cancel all pending tasks
-            for pending_task in pending:
-                pending_task.cancel()
-            # Wait for cancellation to complete
-            if pending:
-                await asyncio.wait(pending)
-            # Re-raise the first exception to notify the caller
             raise first_exception
 
-        # If all tasks completed successfully, collect results
-        processed_entities = [task.result() for task in entity_tasks]
-
         # ===== Phase 2: Process all relationships concurrently =====
         log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
         logger.info(log_message)
@@ -2531,6 +2582,14 @@
 
         async def _locked_process_edges(edge_key, edges):
             async with semaphore:
+                # Check for cancellation before processing edges
+                if pipeline_status is not None and pipeline_status_lock is not None:
+                    async with pipeline_status_lock:
+                        if pipeline_status.get("cancellation_requested", False):
+                            raise PipelineCancelledException(
+                                "User cancelled during relation merge"
+                            )
+
                 workspace = global_config.get("workspace", "")
                 namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
                 sorted_edge_key = sorted([edge_key[0], edge_key[1]])
@@ -2566,7 +2625,7 @@ async def merge_nodes_and_edges(
 
                     return edge_data, added_entities
                 except Exception as e:
-                    error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}"
+                    error_msg = f"Error processing relation `{sorted_edge_key}`: {e}"
                     logger.error(error_msg)
 
                     # Try to update pipeline status, but don't let status update failure affect main exception
@@ -2604,40 +2663,36 @@
             edge_tasks, return_when=asyncio.FIRST_EXCEPTION
         )
 
-        # Check if any task raised an exception and ensure all exceptions are retrieved
         first_exception = None
-        successful_results = []
         for task in done:
             try:
-                exception = task.exception()
-                if exception is not None:
-                    if first_exception is None:
-                        first_exception = exception
-                else:
-                    successful_results.append(task.result())
-            except Exception as e:
+                edge_data, added_entities = task.result()
+            except BaseException as e:
                 if first_exception is None:
                     first_exception = e
+            else:
+                if edge_data is not None:
+                    processed_edges.append(edge_data)
+                all_added_entities.extend(added_entities)
+
+        if pending:
+            for task in pending:
+                task.cancel()
+            pending_results = await asyncio.gather(*pending, return_exceptions=True)
+            for result in pending_results:
+                if isinstance(result, BaseException):
+                    if first_exception is None:
+                        first_exception = result
+                else:
+                    edge_data, added_entities = result
+                    if edge_data is not None:
+                        processed_edges.append(edge_data)
+                    all_added_entities.extend(added_entities)
 
-        # If any task failed, cancel all pending tasks and raise the first exception
         if first_exception is not None:
-            # Cancel all pending tasks
-            for pending_task in pending:
-                pending_task.cancel()
-            # Wait for cancellation to complete
-            if pending:
-                await asyncio.wait(pending)
-            # Re-raise the first exception to notify the caller
             raise first_exception
 
-        # If all tasks completed successfully, collect results
-        for task in edge_tasks:
-            edge_data, added_entities = task.result()
-            if edge_data is not None:
-                processed_edges.append(edge_data)
-            all_added_entities.extend(added_entities)
-
         # ===== Phase 3: Update full_entities and full_relations storage =====
         if full_entities_storage and full_relations_storage and doc_id:
             try:
@@ -2718,6 +2773,14 @@ async def extract_entities(
     llm_response_cache: BaseKVStorage | None = None,
     text_chunks_storage: BaseKVStorage | None = None,
 ) -> list:
+    # Check for cancellation at the start of entity extraction
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            if pipeline_status.get("cancellation_requested", False):
+                raise PipelineCancelledException(
+                    "User cancelled during entity extraction"
+                )
+
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
 
@@ -2885,6 +2948,14 @@
 
     async def _process_with_semaphore(chunk):
         async with semaphore:
+            # Check for cancellation before processing chunk
+            if pipeline_status is not None and pipeline_status_lock is not None:
+                async with pipeline_status_lock:
+                    if pipeline_status.get("cancellation_requested", False):
+                        raise PipelineCancelledException(
+                            "User cancelled during chunk processing"
+                        )
+
             try:
                 return await _process_single_content(chunk)
             except Exception as e:
@@ -3382,10 +3453,10 @@ async def _perform_kg_search(
     )
     query_embedding = None
     if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
-        embedding_func_config = text_chunks_db.embedding_func
-        if embedding_func_config and embedding_func_config.func:
+        actual_embedding_func = text_chunks_db.embedding_func
+        if actual_embedding_func:
             try:
-                query_embedding = await embedding_func_config.func([query])
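+                # EmbeddingFunc instances are directly callable, so the
+                # wrapper is awaited instead of its unwrapped .func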
+                query_embedding = await actual_embedding_func([query])
                 query_embedding = query_embedding[
                     0
                 ]  # Extract first embedding from batch result
@@ -4293,25 +4364,21 @@ async def _find_related_text_unit_from_entities(
         num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
 
     # Get embedding function from global config
-    embedding_func_config = text_chunks_db.embedding_func
-    if not embedding_func_config:
+    actual_embedding_func = text_chunks_db.embedding_func
+    if not actual_embedding_func:
         logger.warning("No embedding function found, falling back to WEIGHT method")
         kg_chunk_pick_method = "WEIGHT"
     else:
         try:
-            actual_embedding_func = embedding_func_config.func
-
-            selected_chunk_ids = None
-            if actual_embedding_func:
-                selected_chunk_ids = await pick_by_vector_similarity(
-                    query=query,
-                    text_chunks_storage=text_chunks_db,
-                    chunks_vdb=chunks_vdb,
-                    num_of_chunks=num_of_chunks,
-                    entity_info=entities_with_chunks,
-                    embedding_func=actual_embedding_func,
-                    query_embedding=query_embedding,
-                )
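+            # actual_embedding_func is known to be set on this branch,
+            # so the picker can be called without an inner None-check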
+            selected_chunk_ids = await pick_by_vector_similarity(
+                query=query,
+                text_chunks_storage=text_chunks_db,
+                chunks_vdb=chunks_vdb,
+                num_of_chunks=num_of_chunks,
+                entity_info=entities_with_chunks,
+                embedding_func=actual_embedding_func,
+                query_embedding=query_embedding,
+            )
 
             if selected_chunk_ids == []:
                 kg_chunk_pick_method = "WEIGHT"
@@ -4586,24 +4653,21 @@ async def _find_related_text_unit_from_relations(
         num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
 
     # Get embedding function from global config
-    embedding_func_config = text_chunks_db.embedding_func
-    if not embedding_func_config:
+    actual_embedding_func = text_chunks_db.embedding_func
+    if not actual_embedding_func:
         logger.warning("No embedding function found, falling back to WEIGHT method")
         kg_chunk_pick_method = "WEIGHT"
     else:
         try:
-            actual_embedding_func = embedding_func_config.func
-
-            if actual_embedding_func:
-                selected_chunk_ids = await pick_by_vector_similarity(
-                    query=query,
-                    text_chunks_storage=text_chunks_db,
-                    chunks_vdb=chunks_vdb,
-                    num_of_chunks=num_of_chunks,
-                    entity_info=relations_with_chunks,
-                    embedding_func=actual_embedding_func,
-                    query_embedding=query_embedding,
-                )
+            selected_chunk_ids = await pick_by_vector_similarity(
+                query=query,
+                text_chunks_storage=text_chunks_db,
+                chunks_vdb=chunks_vdb,
+                num_of_chunks=num_of_chunks,
+                entity_info=relations_with_chunks,
+                embedding_func=actual_embedding_func,
+                query_embedding=query_embedding,
+            )
 
             if selected_chunk_ids == []:
                 kg_chunk_pick_method = "WEIGHT"
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 8c9b7776..65c1e4bc 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import weakref
+import sys
+
 import asyncio
 import html
 import csv
@@ -40,6 +42,35 @@ from lightrag.constants import (
     SOURCE_IDS_LIMIT_METHOD_FIFO,
 )
 
+# Precompile regex pattern for JSON sanitization (module-level, compiled once)
+_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
+
+
+class SafeStreamHandler(logging.StreamHandler):
+    """StreamHandler that gracefully handles closed streams during shutdown.
+
+    This handler prevents "ValueError: I/O operation on closed file" errors
+    that can occur when pytest or other test frameworks close stdout/stderr
+    before Python's logging cleanup runs.
+    """
+
+    def flush(self):
+        """Flush the stream, ignoring errors if the stream is closed."""
+        try:
+            super().flush()
+        except (ValueError, OSError):
+            # Stream is closed or otherwise unavailable, silently ignore
+            pass
+
+    def close(self):
+        """Close the handler, ignoring errors if the stream is already closed."""
+        try:
+            super().close()
+        except (ValueError, OSError):
+            # Stream is closed or otherwise unavailable, silently ignore
+            pass
+
+
 # Initialize logger with basic configuration
 logger = logging.getLogger("lightrag")
 logger.propagate = False  # prevent log message send to root logger
@@ -47,7 +78,7 @@ logger.setLevel(logging.INFO)
 
 # Add console handler if no handlers exist
 if not logger.handlers:
-    console_handler = logging.StreamHandler()
+    console_handler = SafeStreamHandler()
     console_handler.setLevel(logging.INFO)
     formatter = logging.Formatter("%(levelname)s: %(message)s")
     console_handler.setFormatter(formatter)
@@ -56,8 +87,32 @@ if not logger.handlers:
 # Set httpx logging level to WARNING
 logging.getLogger("httpx").setLevel(logging.WARNING)
 
-# Precompile regex pattern for JSON sanitization (module-level, compiled once)
-_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
+
+def _patch_ascii_colors_console_handler() -> None:
+    """Prevent ascii_colors from printing flush errors during interpreter exit."""
+
+    try:
+        from ascii_colors import ConsoleHandler
+    except ImportError:
+        return
+
+    if getattr(ConsoleHandler, "_lightrag_patched", False):
+        return
+
+    original_handle_error = ConsoleHandler.handle_error
+
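+    # Swallow only closed-stream errors raised while flushing at interpreter
+    # exit; all other errors still reach the original handle_error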
+    def _safe_handle_error(self, message: str) -> None:  # type: ignore[override]
+        exc_type, _, _ = sys.exc_info()
+        if exc_type in (ValueError, OSError) and "close" in message.lower():
+            return
+        original_handle_error(self, message)
+
+    ConsoleHandler.handle_error = _safe_handle_error  # type: ignore[assignment]
+    ConsoleHandler._lightrag_patched = True  # type: ignore[attr-defined]
+
+
+_patch_ascii_colors_console_handler()
+
 
 # Global import for pypinyin with startup-time logging
 try:
@@ -286,8 +341,8 @@ def setup_logger(
     logger_instance.handlers = []  # Clear existing handlers
     logger_instance.propagate = False
 
-    # Add console handler
-    console_handler = logging.StreamHandler()
+    # Add console handler with safe stream handling
+    console_handler = SafeStreamHandler()
     console_handler.setFormatter(simple_formatter)
     console_handler.setLevel(level)
     logger_instance.addHandler(console_handler)
@@ -950,7 +1005,76 @@ def priority_limit_async_func_call(
 
 
 def wrap_embedding_func_with_attrs(**kwargs):
-    """Wrap a function with attributes"""
+    """Decorator to add embedding dimension and token limit attributes to embedding functions.
+
+    This decorator wraps an async embedding function and returns an EmbeddingFunc instance
+    that automatically handles dimension parameter injection and attribute management.
+
+    WARNING: DO NOT apply this decorator to wrapper functions that call other
+    decorated embedding functions. This will cause double decoration and parameter
+    injection conflicts.
+
+    Correct usage patterns:
+
+    1. Direct implementation (decorated):
+        ```python
+        @wrap_embedding_func_with_attrs(embedding_dim=1536)
+        async def my_embed(texts, embedding_dim=None):
+            # Direct implementation
+            return embeddings
+        ```
+
+    2. Wrapper calling decorated function (DO NOT decorate wrapper):
+        ```python
+        # my_embed is already decorated above
+
+        async def my_wrapper(texts, **kwargs):  # ❌ DO NOT decorate this!
+            # Must call .func to access unwrapped implementation
+            return await my_embed.func(texts, **kwargs)
+        ```
+
+    3. Wrapper calling decorated function (properly decorated):
+        ```python
+        @wrap_embedding_func_with_attrs(embedding_dim=1536)
+        async def my_wrapper(texts, **kwargs):  # ✅ Can decorate if calling .func
+            # Calling .func avoids double decoration
+            return await my_embed.func(texts, **kwargs)
+        ```
+
+    The decorated function becomes an EmbeddingFunc instance with:
+    - embedding_dim: The embedding dimension
+    - max_token_size: Maximum token limit (optional)
+    - func: The original unwrapped function (access via .func)
+    - __call__: Wrapper that injects embedding_dim parameter
+
+    Double decoration causes:
+    - Double injection of embedding_dim parameter
+    - Incorrect parameter passing to the underlying implementation
+    - Runtime errors due to parameter conflicts
+
+    Args:
+        embedding_dim: The dimension of embedding vectors
+        max_token_size: Maximum number of tokens (optional)
+        send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
+
+    Returns:
+        A decorator that wraps the function as an EmbeddingFunc instance
+
+    Example of correct wrapper implementation:
+        ```python
+        @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+        @retry(...)
+        async def openai_embed(texts, ...):
+            # Base implementation
+            pass
+
+        @wrap_embedding_func_with_attrs(embedding_dim=1536)  # Note: No @retry here!
+        async def azure_openai_embed(texts, ...):
+            # CRITICAL: Call .func to access unwrapped function
+            return await openai_embed.func(texts, ...)  # ✅ Correct
+            # return await openai_embed(texts, ...)  # ❌ Wrong - double decoration!
+        ```
+    """
 
     def final_decro(func) -> EmbeddingFunc:
         new_func = EmbeddingFunc(**kwargs, func=func)