Raphaël MANSUY 2025-12-04 19:18:16 +08:00
parent 0a6e4616b2
commit d85c5a5875
3 changed files with 446 additions and 1359 deletions

View file

@@ -73,6 +73,8 @@ ENABLE_LLM_CACHE=true
 # MAX_RELATION_TOKENS=8000
 ### control the maximum tokens send to LLM (include entities, relations and chunks)
 # MAX_TOTAL_TOKENS=30000
+### control the maximum chunk_ids stored in vector db
+# MAX_CHUNK_IDS_PER_ENTITY=500
 ### maximum number of related chunks per source entity or relation
 ### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph)
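For orientation, the new setting is consumed in the utils hunk further down in this commit. A minimal sketch of how it is read, assuming get_env_value and DEFAULT_MAX_CHUNK_IDS_PER_ENTITY are importable from lightrag.utils and lightrag.constants as the diff below suggests:

    # Sketch only: mirrors the get_env_value call added later in this commit.
    # Import locations are assumptions; the names come from the diff itself.
    from lightrag.constants import DEFAULT_MAX_CHUNK_IDS_PER_ENTITY
    from lightrag.utils import get_env_value

    max_chunk_ids_per_entity = get_env_value(
        "MAX_CHUNK_IDS_PER_ENTITY", DEFAULT_MAX_CHUNK_IDS_PER_ENTITY, int
    )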

File diff suppressed because it is too large

View file

@@ -15,17 +15,7 @@ from dataclasses import dataclass
 from datetime import datetime
 from functools import wraps
 from hashlib import md5
-from typing import (
-    Any,
-    Protocol,
-    Callable,
-    TYPE_CHECKING,
-    List,
-    Optional,
-    Iterable,
-    Sequence,
-    Collection,
-)
+from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional

 import numpy as np
 from dotenv import load_dotenv
@@ -35,9 +25,8 @@ from lightrag.constants import (
     DEFAULT_LOG_FILENAME,
     GRAPH_FIELD_SEP,
     DEFAULT_MAX_TOTAL_TOKENS,
-    DEFAULT_SOURCE_IDS_LIMIT_METHOD,
-    VALID_SOURCE_IDS_LIMIT_METHODS,
-    SOURCE_IDS_LIMIT_METHOD_FIFO,
+    DEFAULT_MAX_FILE_PATH_LENGTH,
+    DEFAULT_MAX_CHUNK_IDS_PER_ENTITY,
 )

 # Initialize logger with basic configuration
@@ -353,29 +342,8 @@ class EmbeddingFunc:
     embedding_dim: int
     func: callable
     max_token_size: int | None = None  # deprecated keep it for compatible only
-    send_dimensions: bool = (
-        False  # Control whether to send embedding_dim to the function
-    )

     async def __call__(self, *args, **kwargs) -> np.ndarray:
-        # Only inject embedding_dim when send_dimensions is True
-        if self.send_dimensions:
-            # Check if user provided embedding_dim parameter
-            if "embedding_dim" in kwargs:
-                user_provided_dim = kwargs["embedding_dim"]
-                # If user's value differs from class attribute, output warning
-                if (
-                    user_provided_dim is not None
-                    and user_provided_dim != self.embedding_dim
-                ):
-                    logger.warning(
-                        f"Ignoring user-provided embedding_dim={user_provided_dim}, "
-                        f"using declared embedding_dim={self.embedding_dim} from decorator"
-                    )
-            # Inject embedding_dim from decorator
-            kwargs["embedding_dim"] = self.embedding_dim
         return await self.func(*args, **kwargs)
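After this hunk, EmbeddingFunc no longer injects embedding_dim into the wrapped callable; it simply awaits it. A minimal usage sketch, where my_embed is a hypothetical async backend and only the embedding_dim and func fields come from the dataclass above:

    import numpy as np
    from lightrag.utils import EmbeddingFunc

    async def my_embed(texts: list[str]) -> np.ndarray:
        # Hypothetical embedding backend: one 1024-dim vector per input text.
        return np.zeros((len(texts), 1024))

    embedding_func = EmbeddingFunc(embedding_dim=1024, func=my_embed)
    # await embedding_func(["hello"])  # forwards straight to my_embed, no dim injection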
@@ -927,45 +895,9 @@ def load_json(file_name):
         return json.load(f)


-def _sanitize_json_data(data: Any) -> Any:
-    """Recursively sanitize all string values in data structure for safe UTF-8 encoding
-
-    Handles all JSON-serializable types including:
-    - Dictionary keys and values
-    - Lists and tuples (preserves type)
-    - Nested structures
-    - Strings at any level
-
-    Args:
-        data: Data to sanitize (dict, list, tuple, str, or other types)
-
-    Returns:
-        Sanitized data with all strings cleaned of problematic characters
-    """
-    if isinstance(data, dict):
-        # Sanitize both keys and values
-        return {
-            _sanitize_string_for_json(k)
-            if isinstance(k, str)
-            else k: _sanitize_json_data(v)
-            for k, v in data.items()
-        }
-    elif isinstance(data, (list, tuple)):
-        # Handle both lists and tuples, preserve original type
-        sanitized = [_sanitize_json_data(item) for item in data]
-        return type(data)(sanitized)
-    elif isinstance(data, str):
-        return sanitize_text_for_encoding(data, replacement_char="")
-    else:
-        # Numbers, booleans, None, etc. - return as-is
-        return data
-
-
 def write_json(json_obj, file_name):
-    # Sanitize data before writing to prevent UTF-8 encoding errors
-    sanitized_obj = _sanitize_json_data(json_obj)
     with open(file_name, "w", encoding="utf-8") as f:
-        json.dump(sanitized_obj, f, indent=2, ensure_ascii=False)
+        json.dump(json_obj, f, indent=2, ensure_ascii=False)


 class TokenizerInterface(Protocol):
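With the sanitization pass removed, write_json now serializes exactly what it is given, so callers are responsible for passing UTF-8-safe data. A minimal usage sketch (file name and payload are illustrative only):

    from lightrag.utils import write_json

    # Any non-encodable strings must now be cleaned before this call.
    write_json({"entity": "ACME Corp", "chunks": ["chunk-1", "chunk-2"]}, "kv_store_example.json")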
@@ -1852,7 +1784,7 @@ def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
     - Filter out short numeric-only text (length < 3 and only digits/dots)
     - remove_inner_quotes = True
         remove Chinese quotes
-        remove English quotes in and around chinese
+        remove English queotes in and around chinese
         Convert non-breaking spaces to regular spaces
         Convert narrow non-breaking spaces after non-digits to regular spaces
@@ -2533,157 +2465,80 @@ async def process_chunks_unified(
     return final_chunks


-def normalize_source_ids_limit_method(method: str | None) -> str:
-    """Normalize the source ID limiting strategy and fall back to default when invalid."""
-    if not method:
-        return DEFAULT_SOURCE_IDS_LIMIT_METHOD
-    normalized = method.upper()
-    if normalized not in VALID_SOURCE_IDS_LIMIT_METHODS:
-        logger.warning(
-            "Unknown SOURCE_IDS_LIMIT_METHOD '%s', falling back to %s",
-            method,
-            DEFAULT_SOURCE_IDS_LIMIT_METHOD,
-        )
-        return DEFAULT_SOURCE_IDS_LIMIT_METHOD
-    return normalized
-
-
-def merge_source_ids(
-    existing_ids: Iterable[str] | None, new_ids: Iterable[str] | None
-) -> list[str]:
-    """Merge two iterables of source IDs while preserving order and removing duplicates."""
-    merged: list[str] = []
-    seen: set[str] = set()
-    for sequence in (existing_ids, new_ids):
-        if not sequence:
-            continue
-        for source_id in sequence:
-            if not source_id:
-                continue
-            if source_id not in seen:
-                seen.add(source_id)
-                merged.append(source_id)
-    return merged
-
-
-def apply_source_ids_limit(
-    source_ids: Sequence[str],
-    limit: int,
-    method: str,
-    *,
-    identifier: str | None = None,
-) -> list[str]:
-    """Apply a limit strategy to a sequence of source IDs."""
-    if limit <= 0:
-        return []
-    source_ids_list = list(source_ids)
-    if len(source_ids_list) <= limit:
-        return source_ids_list
-    normalized_method = normalize_source_ids_limit_method(method)
-    if normalized_method == SOURCE_IDS_LIMIT_METHOD_FIFO:
-        truncated = source_ids_list[-limit:]
-    else:  # IGNORE_NEW
-        truncated = source_ids_list[:limit]
-    if identifier and len(truncated) < len(source_ids_list):
-        logger.debug(
-            "Source_id truncated: %s | %s keeping %s of %s entries",
-            identifier,
-            normalized_method,
-            len(truncated),
-            len(source_ids_list),
-        )
-    return truncated
-
-
-def compute_incremental_chunk_ids(
-    existing_full_chunk_ids: list[str],
-    old_chunk_ids: list[str],
-    new_chunk_ids: list[str],
-) -> list[str]:
-    """
-    Compute incrementally updated chunk IDs based on changes.
-
-    This function applies delta changes (additions and removals) to an existing
-    list of chunk IDs while maintaining order and ensuring deduplication.
-    Delta additions from new_chunk_ids are placed at the end.
-
-    Args:
-        existing_full_chunk_ids: Complete list of existing chunk IDs from storage
-        old_chunk_ids: Previous chunk IDs from source_id (chunks being replaced)
-        new_chunk_ids: New chunk IDs from updated source_id (chunks being added)
-
-    Returns:
-        Updated list of chunk IDs with deduplication
-
-    Example:
-        >>> existing = ['chunk-1', 'chunk-2', 'chunk-3']
-        >>> old = ['chunk-1', 'chunk-2']
-        >>> new = ['chunk-2', 'chunk-4']
-        >>> compute_incremental_chunk_ids(existing, old, new)
-        ['chunk-3', 'chunk-2', 'chunk-4']
-    """
-    # Calculate changes
-    chunks_to_remove = set(old_chunk_ids) - set(new_chunk_ids)
-    chunks_to_add = set(new_chunk_ids) - set(old_chunk_ids)
-
-    # Apply changes to full chunk_ids
-    # Step 1: Remove chunks that are no longer needed
-    updated_chunk_ids = [
-        cid for cid in existing_full_chunk_ids if cid not in chunks_to_remove
-    ]
-
-    # Step 2: Add new chunks (preserving order from new_chunk_ids)
-    # Note: 'cid not in updated_chunk_ids' check ensures deduplication
-    for cid in new_chunk_ids:
-        if cid in chunks_to_add and cid not in updated_chunk_ids:
-            updated_chunk_ids.append(cid)
-
-    return updated_chunk_ids
-
-
-def subtract_source_ids(
-    source_ids: Iterable[str],
-    ids_to_remove: Collection[str],
-) -> list[str]:
-    """Remove a collection of IDs from an ordered iterable while preserving order."""
-    removal_set = set(ids_to_remove)
-    if not removal_set:
-        return [source_id for source_id in source_ids if source_id]
-
-    return [
-        source_id
-        for source_id in source_ids
-        if source_id and source_id not in removal_set
-    ]
-
-
-def make_relation_chunk_key(src: str, tgt: str) -> str:
-    """Create a deterministic storage key for relation chunk tracking."""
-    return GRAPH_FIELD_SEP.join(sorted((src, tgt)))
-
-
-def parse_relation_chunk_key(key: str) -> tuple[str, str]:
-    """Parse a relation chunk storage key back into its entity pair."""
-    parts = key.split(GRAPH_FIELD_SEP)
-    if len(parts) != 2:
-        raise ValueError(f"Invalid relation chunk key: {key}")
-    return parts[0], parts[1]
+def truncate_entity_source_id(chunk_ids: set, entity_name: str) -> set:
+    """Limit chunk_ids, for entities that appear a HUGE no of times (To not break VDB hard upper limits)"""
+    already_len: int = len(chunk_ids)
+    max_chunk_ids_per_entity = get_env_value("MAX_CHUNK_IDS_PER_ENTITY", DEFAULT_MAX_CHUNK_IDS_PER_ENTITY, int)
+
+    if already_len >= max_chunk_ids_per_entity:
+        logger.warning(
+            f"Chunk Ids already exceeds {max_chunk_ids_per_entity } for {entity_name}, "
+            f"current size: {already_len} entries."
+        )
+
+    truncated_chunk_ids = set(list(chunk_ids)[0:max_chunk_ids_per_entity ])
+    return truncated_chunk_ids
+
+
+def build_file_path(already_file_paths, data_list, target):
+    """Build file path string with UTF-8 byte length limit and deduplication
+
+    Args:
+        already_file_paths: List of existing file paths
+        data_list: List of data items containing file_path
+        target: Target name for logging warnings
+
+    Returns:
+        str: Combined file paths separated by GRAPH_FIELD_SEP
+    """
+    # set: deduplication
+    file_paths_set = {fp for fp in already_file_paths if fp}
+
+    # string: filter empty value and keep file order in already_file_paths
+    file_paths = GRAPH_FIELD_SEP.join(fp for fp in already_file_paths if fp)
+
+    # Check if initial file_paths already exceeds byte length limit
+    if len(file_paths.encode("utf-8")) >= DEFAULT_MAX_FILE_PATH_LENGTH:
+        logger.warning(
+            f"Initial file_paths already exceeds {DEFAULT_MAX_FILE_PATH_LENGTH} bytes for {target}, "
+            f"current size: {len(file_paths.encode('utf-8'))} bytes"
+        )
+
+    # ignored file_paths
+    file_paths_ignore = ""
+
+    # add file_paths
+    for dp in data_list:
+        cur_file_path = dp.get("file_path")
+        # empty
+        if not cur_file_path:
+            continue
+        # skip duplicate item
+        if cur_file_path in file_paths_set:
+            continue
+        # add
+        file_paths_set.add(cur_file_path)
+
+        # check the UTF-8 byte length
+        new_addition = GRAPH_FIELD_SEP + cur_file_path if file_paths else cur_file_path
+        if (
+            len(file_paths.encode("utf-8")) + len(new_addition.encode("utf-8"))
+            < DEFAULT_MAX_FILE_PATH_LENGTH - 5
+        ):
+            # append
+            file_paths += new_addition
+        else:
+            # ignore
+            file_paths_ignore += GRAPH_FIELD_SEP + cur_file_path
+
+    if file_paths_ignore:
+        logger.warning(
+            f"File paths exceed {DEFAULT_MAX_FILE_PATH_LENGTH} bytes for {target}, "
+            f"ignoring file path: {file_paths_ignore}"
+        )
+    return file_paths


 def generate_track_id(prefix: str = "upload") -> str:
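For orientation, a minimal usage sketch of the two helpers added above, with illustrative values and assuming they are importable from lightrag.utils as the surrounding hunks suggest:

    from lightrag.utils import build_file_path, truncate_entity_source_id

    # Cap a very large chunk-id set at MAX_CHUNK_IDS_PER_ENTITY (or its default).
    capped = truncate_entity_source_id({f"chunk-{i}" for i in range(10_000)}, "ACME Corp")

    # Merge new file paths into an existing list under the byte-length budget.
    paths = build_file_path(
        already_file_paths=["docs/a.txt"],
        data_list=[{"file_path": "docs/b.txt"}, {"file_path": "docs/a.txt"}],
        target="entity 'ACME Corp'",
    )

Because Python sets are unordered, the truncation keeps an arbitrary subset of the chunk ids rather than the oldest or newest entries.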
@@ -2773,9 +2628,9 @@ def fix_tuple_delimiter_corruption(
         record,
     )

-    # Fix: <X|#|> -> <|#|>, <|#|Y> -> <|#|>, <X|#|Y> -> <|#|>, <||#||> -> <|#|> (one extra characters outside pipes)
+    # Fix: <X|#|> -> <|#|>, <|#|Y> -> <|#|>, <X|#|Y> -> <|#|>, <||#||> -> <|#|>, <||#> -> <|#|> (one extra characters outside pipes)
     record = re.sub(
-        rf"<.?\|{escaped_delimiter_core}\|.?>",
+        rf"<.?\|{escaped_delimiter_core}\|*?>",
         tuple_delimiter,
         record,
     )
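To see what the revised pattern covers, a small standalone check, assuming the delimiter core is "#" so the clean marker is <|#|>:

    import re

    tuple_delimiter = "<|#|>"
    escaped_delimiter_core = re.escape("#")  # assumption: "#" is the delimiter core

    # The lazy "\|*?" lets the pattern also swallow corrupted forms such as <||#>.
    pattern = rf"<.?\|{escaped_delimiter_core}\|*?>"
    cleaned = re.sub(pattern, tuple_delimiter, "alpha<||#>beta<X|#|>gamma")
    print(cleaned)  # alpha<|#|>beta<|#|>gamma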
@ -2795,6 +2650,7 @@ def fix_tuple_delimiter_corruption(
) )
# Fix: <|#| -> <|#|>, <|#|| -> <|#|> (missing closing >) # Fix: <|#| -> <|#|>, <|#|| -> <|#|> (missing closing >)
record = re.sub( record = re.sub(
rf"<\|{escaped_delimiter_core}\|+(?!>)", rf"<\|{escaped_delimiter_core}\|+(?!>)",
tuple_delimiter, tuple_delimiter,
@ -2808,13 +2664,6 @@ def fix_tuple_delimiter_corruption(
record, record,
) )
# Fix: <||#> -> <|#|> (double pipe at start, missing pipe at end)
record = re.sub(
rf"<\|+{escaped_delimiter_core}>",
tuple_delimiter,
record,
)
# Fix: <|| -> <|#|> # Fix: <|| -> <|#|>
record = re.sub( record = re.sub(
r"<\|\|(?!>)", r"<\|\|(?!>)",