fix: sync core modules from upstream
This commit is contained in:
parent
f6f3ed93d3
commit
d1262e999d
2 changed files with 287 additions and 99 deletions
|
|
@ -8,6 +8,10 @@ import json_repair
|
||||||
from typing import Any, AsyncIterator, overload, Literal
|
from typing import Any, AsyncIterator, overload, Literal
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
|
|
||||||
|
from lightrag.exceptions import (
|
||||||
|
PipelineCancelledException,
|
||||||
|
ChunkTokenLimitExceededError,
|
||||||
|
)
|
||||||
from lightrag.utils import (
|
from lightrag.utils import (
|
||||||
logger,
|
logger,
|
||||||
compute_mdhash_id,
|
compute_mdhash_id,
|
||||||
|
|
@ -82,11 +86,11 @@ def _truncate_entity_identifier(
|
||||||
display_value = identifier[:limit]
|
display_value = identifier[:limit]
|
||||||
preview = identifier[:20] # Show first 20 characters as preview
|
preview = identifier[:20] # Show first 20 characters as preview
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"%s: %s exceeded %d characters (len: %d, preview: '%s...'",
|
"%s: %s len %d > %d chars (Name: '%s...')",
|
||||||
chunk_key,
|
chunk_key,
|
||||||
identifier_role,
|
identifier_role,
|
||||||
limit,
|
|
||||||
len(identifier),
|
len(identifier),
|
||||||
|
limit,
|
||||||
preview,
|
preview,
|
||||||
)
|
)
|
||||||
return display_value
|
return display_value
|
||||||
|
|
@ -97,8 +101,8 @@ def chunking_by_token_size(
|
||||||
content: str,
|
content: str,
|
||||||
split_by_character: str | None = None,
|
split_by_character: str | None = None,
|
||||||
split_by_character_only: bool = False,
|
split_by_character_only: bool = False,
|
||||||
overlap_token_size: int = 128,
|
chunk_overlap_token_size: int = 100,
|
||||||
max_token_size: int = 1024,
|
chunk_token_size: int = 1200,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
tokens = tokenizer.encode(content)
|
tokens = tokenizer.encode(content)
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
|
|
@ -108,19 +112,30 @@ def chunking_by_token_size(
|
||||||
if split_by_character_only:
|
if split_by_character_only:
|
||||||
for chunk in raw_chunks:
|
for chunk in raw_chunks:
|
||||||
_tokens = tokenizer.encode(chunk)
|
_tokens = tokenizer.encode(chunk)
|
||||||
|
if len(_tokens) > chunk_token_size:
|
||||||
|
logger.warning(
|
||||||
|
"Chunk split_by_character exceeds token limit: len=%d limit=%d",
|
||||||
|
len(_tokens),
|
||||||
|
chunk_token_size,
|
||||||
|
)
|
||||||
|
raise ChunkTokenLimitExceededError(
|
||||||
|
chunk_tokens=len(_tokens),
|
||||||
|
chunk_token_limit=chunk_token_size,
|
||||||
|
chunk_preview=chunk[:120],
|
||||||
|
)
|
||||||
new_chunks.append((len(_tokens), chunk))
|
new_chunks.append((len(_tokens), chunk))
|
||||||
else:
|
else:
|
||||||
for chunk in raw_chunks:
|
for chunk in raw_chunks:
|
||||||
_tokens = tokenizer.encode(chunk)
|
_tokens = tokenizer.encode(chunk)
|
||||||
if len(_tokens) > max_token_size:
|
if len(_tokens) > chunk_token_size:
|
||||||
for start in range(
|
for start in range(
|
||||||
0, len(_tokens), max_token_size - overlap_token_size
|
0, len(_tokens), chunk_token_size - chunk_overlap_token_size
|
||||||
):
|
):
|
||||||
chunk_content = tokenizer.decode(
|
chunk_content = tokenizer.decode(
|
||||||
_tokens[start : start + max_token_size]
|
_tokens[start : start + chunk_token_size]
|
||||||
)
|
)
|
||||||
new_chunks.append(
|
new_chunks.append(
|
||||||
(min(max_token_size, len(_tokens) - start), chunk_content)
|
(min(chunk_token_size, len(_tokens) - start), chunk_content)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
new_chunks.append((len(_tokens), chunk))
|
new_chunks.append((len(_tokens), chunk))
|
||||||
|
|
@ -134,12 +149,12 @@ def chunking_by_token_size(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
for index, start in enumerate(
|
for index, start in enumerate(
|
||||||
range(0, len(tokens), max_token_size - overlap_token_size)
|
range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)
|
||||||
):
|
):
|
||||||
chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
|
chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size])
|
||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"tokens": min(max_token_size, len(tokens) - start),
|
"tokens": min(chunk_token_size, len(tokens) - start),
|
||||||
"content": chunk_content.strip(),
|
"content": chunk_content.strip(),
|
||||||
"chunk_order_index": index,
|
"chunk_order_index": index,
|
||||||
}
|
}
|
||||||
|
|
@ -344,6 +359,20 @@ async def _summarize_descriptions(
|
||||||
llm_response_cache=llm_response_cache,
|
llm_response_cache=llm_response_cache,
|
||||||
cache_type="summary",
|
cache_type="summary",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check summary token length against embedding limit
|
||||||
|
embedding_token_limit = global_config.get("embedding_token_limit")
|
||||||
|
if embedding_token_limit is not None and summary:
|
||||||
|
tokenizer = global_config["tokenizer"]
|
||||||
|
summary_token_count = len(tokenizer.encode(summary))
|
||||||
|
threshold = int(embedding_token_limit * 0.9)
|
||||||
|
|
||||||
|
if summary_token_count > threshold:
|
||||||
|
logger.warning(
|
||||||
|
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
|
||||||
|
f"({embedding_token_limit}) for {description_type}: {description_name}"
|
||||||
|
)
|
||||||
|
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -988,13 +1017,13 @@ async def _process_extraction_result(
|
||||||
relationship_data["src_id"],
|
relationship_data["src_id"],
|
||||||
DEFAULT_ENTITY_NAME_MAX_LENGTH,
|
DEFAULT_ENTITY_NAME_MAX_LENGTH,
|
||||||
chunk_key,
|
chunk_key,
|
||||||
"Relationship source entity",
|
"Relation entity",
|
||||||
)
|
)
|
||||||
truncated_target = _truncate_entity_identifier(
|
truncated_target = _truncate_entity_identifier(
|
||||||
relationship_data["tgt_id"],
|
relationship_data["tgt_id"],
|
||||||
DEFAULT_ENTITY_NAME_MAX_LENGTH,
|
DEFAULT_ENTITY_NAME_MAX_LENGTH,
|
||||||
chunk_key,
|
chunk_key,
|
||||||
"Relationship target entity",
|
"Relation entity",
|
||||||
)
|
)
|
||||||
relationship_data["src_id"] = truncated_source
|
relationship_data["src_id"] = truncated_source
|
||||||
relationship_data["tgt_id"] = truncated_target
|
relationship_data["tgt_id"] = truncated_target
|
||||||
|
|
@ -1694,6 +1723,12 @@ async def _merge_nodes_then_upsert(
|
||||||
logger.error(f"Entity {entity_name} has no description")
|
logger.error(f"Entity {entity_name} has no description")
|
||||||
raise ValueError(f"Entity {entity_name} has no description")
|
raise ValueError(f"Entity {entity_name} has no description")
|
||||||
|
|
||||||
|
# Check for cancellation before LLM summary
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException("User cancelled during entity summary")
|
||||||
|
|
||||||
# 8. Get summary description an LLM usage status
|
# 8. Get summary description an LLM usage status
|
||||||
description, llm_was_used = await _handle_entity_relation_summary(
|
description, llm_was_used = await _handle_entity_relation_summary(
|
||||||
"Entity",
|
"Entity",
|
||||||
|
|
@ -2015,6 +2050,14 @@ async def _merge_edges_then_upsert(
|
||||||
logger.error(f"Relation {src_id}~{tgt_id} has no description")
|
logger.error(f"Relation {src_id}~{tgt_id} has no description")
|
||||||
raise ValueError(f"Relation {src_id}~{tgt_id} has no description")
|
raise ValueError(f"Relation {src_id}~{tgt_id} has no description")
|
||||||
|
|
||||||
|
# Check for cancellation before LLM summary
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled during relation summary"
|
||||||
|
)
|
||||||
|
|
||||||
# 8. Get summary description an LLM usage status
|
# 8. Get summary description an LLM usage status
|
||||||
description, llm_was_used = await _handle_entity_relation_summary(
|
description, llm_was_used = await _handle_entity_relation_summary(
|
||||||
"Relation",
|
"Relation",
|
||||||
|
|
@ -2396,6 +2439,12 @@ async def merge_nodes_and_edges(
|
||||||
file_path: File path for logging
|
file_path: File path for logging
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Check for cancellation at the start of merge
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException("User cancelled during merge phase")
|
||||||
|
|
||||||
# Collect all nodes and edges from all chunks
|
# Collect all nodes and edges from all chunks
|
||||||
all_nodes = defaultdict(list)
|
all_nodes = defaultdict(list)
|
||||||
all_edges = defaultdict(list)
|
all_edges = defaultdict(list)
|
||||||
|
|
@ -2432,6 +2481,14 @@ async def merge_nodes_and_edges(
|
||||||
|
|
||||||
async def _locked_process_entity_name(entity_name, entities):
|
async def _locked_process_entity_name(entity_name, entities):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
|
# Check for cancellation before processing entity
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled during entity merge"
|
||||||
|
)
|
||||||
|
|
||||||
workspace = global_config.get("workspace", "")
|
workspace = global_config.get("workspace", "")
|
||||||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||||||
async with get_storage_keyed_lock(
|
async with get_storage_keyed_lock(
|
||||||
|
|
@ -2454,9 +2511,7 @@ async def merge_nodes_and_edges(
|
||||||
return entity_data
|
return entity_data
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = (
|
error_msg = f"Error processing entity `{entity_name}`: {e}"
|
||||||
f"Critical error in entity processing for `{entity_name}`: {e}"
|
|
||||||
)
|
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
|
|
||||||
# Try to update pipeline status, but don't let status update failure affect main exception
|
# Try to update pipeline status, but don't let status update failure affect main exception
|
||||||
|
|
@ -2492,36 +2547,32 @@ async def merge_nodes_and_edges(
|
||||||
entity_tasks, return_when=asyncio.FIRST_EXCEPTION
|
entity_tasks, return_when=asyncio.FIRST_EXCEPTION
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if any task raised an exception and ensure all exceptions are retrieved
|
|
||||||
first_exception = None
|
first_exception = None
|
||||||
successful_results = []
|
processed_entities = []
|
||||||
|
|
||||||
for task in done:
|
for task in done:
|
||||||
try:
|
try:
|
||||||
exception = task.exception()
|
result = task.result()
|
||||||
if exception is not None:
|
except BaseException as e:
|
||||||
if first_exception is None:
|
|
||||||
first_exception = exception
|
|
||||||
else:
|
|
||||||
successful_results.append(task.result())
|
|
||||||
except Exception as e:
|
|
||||||
if first_exception is None:
|
if first_exception is None:
|
||||||
first_exception = e
|
first_exception = e
|
||||||
|
else:
|
||||||
|
processed_entities.append(result)
|
||||||
|
|
||||||
|
if pending:
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
pending_results = await asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
for result in pending_results:
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
if first_exception is None:
|
||||||
|
first_exception = result
|
||||||
|
else:
|
||||||
|
processed_entities.append(result)
|
||||||
|
|
||||||
# If any task failed, cancel all pending tasks and raise the first exception
|
|
||||||
if first_exception is not None:
|
if first_exception is not None:
|
||||||
# Cancel all pending tasks
|
|
||||||
for pending_task in pending:
|
|
||||||
pending_task.cancel()
|
|
||||||
# Wait for cancellation to complete
|
|
||||||
if pending:
|
|
||||||
await asyncio.wait(pending)
|
|
||||||
# Re-raise the first exception to notify the caller
|
|
||||||
raise first_exception
|
raise first_exception
|
||||||
|
|
||||||
# If all tasks completed successfully, collect results
|
|
||||||
processed_entities = [task.result() for task in entity_tasks]
|
|
||||||
|
|
||||||
# ===== Phase 2: Process all relationships concurrently =====
|
# ===== Phase 2: Process all relationships concurrently =====
|
||||||
log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
|
log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
|
||||||
logger.info(log_message)
|
logger.info(log_message)
|
||||||
|
|
@ -2531,6 +2582,14 @@ async def merge_nodes_and_edges(
|
||||||
|
|
||||||
async def _locked_process_edges(edge_key, edges):
|
async def _locked_process_edges(edge_key, edges):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
|
# Check for cancellation before processing edges
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled during relation merge"
|
||||||
|
)
|
||||||
|
|
||||||
workspace = global_config.get("workspace", "")
|
workspace = global_config.get("workspace", "")
|
||||||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||||||
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
||||||
|
|
@ -2566,7 +2625,7 @@ async def merge_nodes_and_edges(
|
||||||
return edge_data, added_entities
|
return edge_data, added_entities
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}"
|
error_msg = f"Error processing relation `{sorted_edge_key}`: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
|
|
||||||
# Try to update pipeline status, but don't let status update failure affect main exception
|
# Try to update pipeline status, but don't let status update failure affect main exception
|
||||||
|
|
@ -2604,40 +2663,36 @@ async def merge_nodes_and_edges(
|
||||||
edge_tasks, return_when=asyncio.FIRST_EXCEPTION
|
edge_tasks, return_when=asyncio.FIRST_EXCEPTION
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if any task raised an exception and ensure all exceptions are retrieved
|
|
||||||
first_exception = None
|
first_exception = None
|
||||||
successful_results = []
|
|
||||||
|
|
||||||
for task in done:
|
for task in done:
|
||||||
try:
|
try:
|
||||||
exception = task.exception()
|
edge_data, added_entities = task.result()
|
||||||
if exception is not None:
|
except BaseException as e:
|
||||||
if first_exception is None:
|
|
||||||
first_exception = exception
|
|
||||||
else:
|
|
||||||
successful_results.append(task.result())
|
|
||||||
except Exception as e:
|
|
||||||
if first_exception is None:
|
if first_exception is None:
|
||||||
first_exception = e
|
first_exception = e
|
||||||
|
else:
|
||||||
|
if edge_data is not None:
|
||||||
|
processed_edges.append(edge_data)
|
||||||
|
all_added_entities.extend(added_entities)
|
||||||
|
|
||||||
|
if pending:
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
pending_results = await asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
for result in pending_results:
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
if first_exception is None:
|
||||||
|
first_exception = result
|
||||||
|
else:
|
||||||
|
edge_data, added_entities = result
|
||||||
|
if edge_data is not None:
|
||||||
|
processed_edges.append(edge_data)
|
||||||
|
all_added_entities.extend(added_entities)
|
||||||
|
|
||||||
# If any task failed, cancel all pending tasks and raise the first exception
|
|
||||||
if first_exception is not None:
|
if first_exception is not None:
|
||||||
# Cancel all pending tasks
|
|
||||||
for pending_task in pending:
|
|
||||||
pending_task.cancel()
|
|
||||||
# Wait for cancellation to complete
|
|
||||||
if pending:
|
|
||||||
await asyncio.wait(pending)
|
|
||||||
# Re-raise the first exception to notify the caller
|
|
||||||
raise first_exception
|
raise first_exception
|
||||||
|
|
||||||
# If all tasks completed successfully, collect results
|
|
||||||
for task in edge_tasks:
|
|
||||||
edge_data, added_entities = task.result()
|
|
||||||
if edge_data is not None:
|
|
||||||
processed_edges.append(edge_data)
|
|
||||||
all_added_entities.extend(added_entities)
|
|
||||||
|
|
||||||
# ===== Phase 3: Update full_entities and full_relations storage =====
|
# ===== Phase 3: Update full_entities and full_relations storage =====
|
||||||
if full_entities_storage and full_relations_storage and doc_id:
|
if full_entities_storage and full_relations_storage and doc_id:
|
||||||
try:
|
try:
|
||||||
|
|
@ -2718,6 +2773,14 @@ async def extract_entities(
|
||||||
llm_response_cache: BaseKVStorage | None = None,
|
llm_response_cache: BaseKVStorage | None = None,
|
||||||
text_chunks_storage: BaseKVStorage | None = None,
|
text_chunks_storage: BaseKVStorage | None = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
|
# Check for cancellation at the start of entity extraction
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled during entity extraction"
|
||||||
|
)
|
||||||
|
|
||||||
use_llm_func: callable = global_config["llm_model_func"]
|
use_llm_func: callable = global_config["llm_model_func"]
|
||||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||||
|
|
||||||
|
|
@ -2885,6 +2948,14 @@ async def extract_entities(
|
||||||
|
|
||||||
async def _process_with_semaphore(chunk):
|
async def _process_with_semaphore(chunk):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
|
# Check for cancellation before processing chunk
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled during chunk processing"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await _process_single_content(chunk)
|
return await _process_single_content(chunk)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -3382,10 +3453,10 @@ async def _perform_kg_search(
|
||||||
)
|
)
|
||||||
query_embedding = None
|
query_embedding = None
|
||||||
if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
|
if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
|
||||||
embedding_func_config = text_chunks_db.embedding_func
|
actual_embedding_func = text_chunks_db.embedding_func
|
||||||
if embedding_func_config and embedding_func_config.func:
|
if actual_embedding_func:
|
||||||
try:
|
try:
|
||||||
query_embedding = await embedding_func_config.func([query])
|
query_embedding = await actual_embedding_func([query])
|
||||||
query_embedding = query_embedding[
|
query_embedding = query_embedding[
|
||||||
0
|
0
|
||||||
] # Extract first embedding from batch result
|
] # Extract first embedding from batch result
|
||||||
|
|
@ -4293,25 +4364,21 @@ async def _find_related_text_unit_from_entities(
|
||||||
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
|
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
|
||||||
|
|
||||||
# Get embedding function from global config
|
# Get embedding function from global config
|
||||||
embedding_func_config = text_chunks_db.embedding_func
|
actual_embedding_func = text_chunks_db.embedding_func
|
||||||
if not embedding_func_config:
|
if not actual_embedding_func:
|
||||||
logger.warning("No embedding function found, falling back to WEIGHT method")
|
logger.warning("No embedding function found, falling back to WEIGHT method")
|
||||||
kg_chunk_pick_method = "WEIGHT"
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
actual_embedding_func = embedding_func_config.func
|
selected_chunk_ids = await pick_by_vector_similarity(
|
||||||
|
query=query,
|
||||||
selected_chunk_ids = None
|
text_chunks_storage=text_chunks_db,
|
||||||
if actual_embedding_func:
|
chunks_vdb=chunks_vdb,
|
||||||
selected_chunk_ids = await pick_by_vector_similarity(
|
num_of_chunks=num_of_chunks,
|
||||||
query=query,
|
entity_info=entities_with_chunks,
|
||||||
text_chunks_storage=text_chunks_db,
|
embedding_func=actual_embedding_func,
|
||||||
chunks_vdb=chunks_vdb,
|
query_embedding=query_embedding,
|
||||||
num_of_chunks=num_of_chunks,
|
)
|
||||||
entity_info=entities_with_chunks,
|
|
||||||
embedding_func=actual_embedding_func,
|
|
||||||
query_embedding=query_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
if selected_chunk_ids == []:
|
if selected_chunk_ids == []:
|
||||||
kg_chunk_pick_method = "WEIGHT"
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
|
@ -4586,24 +4653,21 @@ async def _find_related_text_unit_from_relations(
|
||||||
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
|
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
|
||||||
|
|
||||||
# Get embedding function from global config
|
# Get embedding function from global config
|
||||||
embedding_func_config = text_chunks_db.embedding_func
|
actual_embedding_func = text_chunks_db.embedding_func
|
||||||
if not embedding_func_config:
|
if not actual_embedding_func:
|
||||||
logger.warning("No embedding function found, falling back to WEIGHT method")
|
logger.warning("No embedding function found, falling back to WEIGHT method")
|
||||||
kg_chunk_pick_method = "WEIGHT"
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
actual_embedding_func = embedding_func_config.func
|
selected_chunk_ids = await pick_by_vector_similarity(
|
||||||
|
query=query,
|
||||||
if actual_embedding_func:
|
text_chunks_storage=text_chunks_db,
|
||||||
selected_chunk_ids = await pick_by_vector_similarity(
|
chunks_vdb=chunks_vdb,
|
||||||
query=query,
|
num_of_chunks=num_of_chunks,
|
||||||
text_chunks_storage=text_chunks_db,
|
entity_info=relations_with_chunks,
|
||||||
chunks_vdb=chunks_vdb,
|
embedding_func=actual_embedding_func,
|
||||||
num_of_chunks=num_of_chunks,
|
query_embedding=query_embedding,
|
||||||
entity_info=relations_with_chunks,
|
)
|
||||||
embedding_func=actual_embedding_func,
|
|
||||||
query_embedding=query_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
if selected_chunk_ids == []:
|
if selected_chunk_ids == []:
|
||||||
kg_chunk_pick_method = "WEIGHT"
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import html
|
import html
|
||||||
import csv
|
import csv
|
||||||
|
|
@ -40,6 +42,35 @@ from lightrag.constants import (
|
||||||
SOURCE_IDS_LIMIT_METHOD_FIFO,
|
SOURCE_IDS_LIMIT_METHOD_FIFO,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
|
||||||
|
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
|
||||||
|
|
||||||
|
|
||||||
|
class SafeStreamHandler(logging.StreamHandler):
|
||||||
|
"""StreamHandler that gracefully handles closed streams during shutdown.
|
||||||
|
|
||||||
|
This handler prevents "ValueError: I/O operation on closed file" errors
|
||||||
|
that can occur when pytest or other test frameworks close stdout/stderr
|
||||||
|
before Python's logging cleanup runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
"""Flush the stream, ignoring errors if the stream is closed."""
|
||||||
|
try:
|
||||||
|
super().flush()
|
||||||
|
except (ValueError, OSError):
|
||||||
|
# Stream is closed or otherwise unavailable, silently ignore
|
||||||
|
pass
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the handler, ignoring errors if the stream is already closed."""
|
||||||
|
try:
|
||||||
|
super().close()
|
||||||
|
except (ValueError, OSError):
|
||||||
|
# Stream is closed or otherwise unavailable, silently ignore
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Initialize logger with basic configuration
|
# Initialize logger with basic configuration
|
||||||
logger = logging.getLogger("lightrag")
|
logger = logging.getLogger("lightrag")
|
||||||
logger.propagate = False # prevent log message send to root logger
|
logger.propagate = False # prevent log message send to root logger
|
||||||
|
|
@ -47,7 +78,7 @@ logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
# Add console handler if no handlers exist
|
# Add console handler if no handlers exist
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = SafeStreamHandler()
|
||||||
console_handler.setLevel(logging.INFO)
|
console_handler.setLevel(logging.INFO)
|
||||||
formatter = logging.Formatter("%(levelname)s: %(message)s")
|
formatter = logging.Formatter("%(levelname)s: %(message)s")
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
|
|
@ -56,8 +87,32 @@ if not logger.handlers:
|
||||||
# Set httpx logging level to WARNING
|
# Set httpx logging level to WARNING
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
|
||||||
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
|
|
||||||
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
|
def _patch_ascii_colors_console_handler() -> None:
|
||||||
|
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ascii_colors import ConsoleHandler
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
if getattr(ConsoleHandler, "_lightrag_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
original_handle_error = ConsoleHandler.handle_error
|
||||||
|
|
||||||
|
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
|
||||||
|
exc_type, _, _ = sys.exc_info()
|
||||||
|
if exc_type in (ValueError, OSError) and "close" in message.lower():
|
||||||
|
return
|
||||||
|
original_handle_error(self, message)
|
||||||
|
|
||||||
|
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
|
||||||
|
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
_patch_ascii_colors_console_handler()
|
||||||
|
|
||||||
|
|
||||||
# Global import for pypinyin with startup-time logging
|
# Global import for pypinyin with startup-time logging
|
||||||
try:
|
try:
|
||||||
|
|
@ -286,8 +341,8 @@ def setup_logger(
|
||||||
logger_instance.handlers = [] # Clear existing handlers
|
logger_instance.handlers = [] # Clear existing handlers
|
||||||
logger_instance.propagate = False
|
logger_instance.propagate = False
|
||||||
|
|
||||||
# Add console handler
|
# Add console handler with safe stream handling
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = SafeStreamHandler()
|
||||||
console_handler.setFormatter(simple_formatter)
|
console_handler.setFormatter(simple_formatter)
|
||||||
console_handler.setLevel(level)
|
console_handler.setLevel(level)
|
||||||
logger_instance.addHandler(console_handler)
|
logger_instance.addHandler(console_handler)
|
||||||
|
|
@ -950,7 +1005,76 @@ def priority_limit_async_func_call(
|
||||||
|
|
||||||
|
|
||||||
def wrap_embedding_func_with_attrs(**kwargs):
|
def wrap_embedding_func_with_attrs(**kwargs):
|
||||||
"""Wrap a function with attributes"""
|
"""Decorator to add embedding dimension and token limit attributes to embedding functions.
|
||||||
|
|
||||||
|
This decorator wraps an async embedding function and returns an EmbeddingFunc instance
|
||||||
|
that automatically handles dimension parameter injection and attribute management.
|
||||||
|
|
||||||
|
WARNING: DO NOT apply this decorator to wrapper functions that call other
|
||||||
|
decorated embedding functions. This will cause double decoration and parameter
|
||||||
|
injection conflicts.
|
||||||
|
|
||||||
|
Correct usage patterns:
|
||||||
|
|
||||||
|
1. Direct implementation (decorated):
|
||||||
|
```python
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
||||||
|
async def my_embed(texts, embedding_dim=None):
|
||||||
|
# Direct implementation
|
||||||
|
return embeddings
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Wrapper calling decorated function (DO NOT decorate wrapper):
|
||||||
|
```python
|
||||||
|
# my_embed is already decorated above
|
||||||
|
|
||||||
|
async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
|
||||||
|
# Must call .func to access unwrapped implementation
|
||||||
|
return await my_embed.func(texts, **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Wrapper calling decorated function (properly decorated):
|
||||||
|
```python
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
||||||
|
async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
|
||||||
|
# Calling .func avoids double decoration
|
||||||
|
return await my_embed.func(texts, **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
The decorated function becomes an EmbeddingFunc instance with:
|
||||||
|
- embedding_dim: The embedding dimension
|
||||||
|
- max_token_size: Maximum token limit (optional)
|
||||||
|
- func: The original unwrapped function (access via .func)
|
||||||
|
- __call__: Wrapper that injects embedding_dim parameter
|
||||||
|
|
||||||
|
Double decoration causes:
|
||||||
|
- Double injection of embedding_dim parameter
|
||||||
|
- Incorrect parameter passing to the underlying implementation
|
||||||
|
- Runtime errors due to parameter conflicts
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_dim: The dimension of embedding vectors
|
||||||
|
max_token_size: Maximum number of tokens (optional)
|
||||||
|
send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A decorator that wraps the function as an EmbeddingFunc instance
|
||||||
|
|
||||||
|
Example of correct wrapper implementation:
|
||||||
|
```python
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||||
|
@retry(...)
|
||||||
|
async def openai_embed(texts, ...):
|
||||||
|
# Base implementation
|
||||||
|
pass
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
|
||||||
|
async def azure_openai_embed(texts, ...):
|
||||||
|
# CRITICAL: Call .func to access unwrapped function
|
||||||
|
return await openai_embed.func(texts, ...) # ✅ Correct
|
||||||
|
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def final_decro(func) -> EmbeddingFunc:
|
def final_decro(func) -> EmbeddingFunc:
|
||||||
new_func = EmbeddingFunc(**kwargs, func=func)
|
new_func = EmbeddingFunc(**kwargs, func=func)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue